diff --git a/license-header b/license-header
deleted file mode 100644
index 12bf9309e9a..00000000000
--- a/license-header
+++ /dev/null
@@ -1,14 +0,0 @@
-/*
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- =======================================================================
- */
\ No newline at end of file
diff --git a/pom.xml b/pom.xml
index 76504524ae9..7288f3661b3 100644
--- a/pom.xml
+++ b/pom.xml
@@ -342,7 +342,23 @@
- ./license-header
+
+/* Copyright $YEAR The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ =======================================================================
+ */
+
diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java
index 4044838de87..92e4cabdbd1 100644
--- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java
+++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java
@@ -354,10 +354,10 @@ public final class Ops {
public final SparseOps sparse;
- public final TpuOps tpu;
-
public final BitwiseOps bitwise;
+ public final TpuOps tpu;
+
public final MathOps math;
public final AudioOps audio;
@@ -385,8 +385,8 @@ private Ops(Scope scope) {
random = new RandomOps(this);
strings = new StringsOps(this);
sparse = new SparseOps(this);
- tpu = new TpuOps(this);
bitwise = new BitwiseOps(this);
+ tpu = new TpuOps(this);
math = new MathOps(this);
audio = new AudioOps(this);
signal = new SignalOps(this);
@@ -7884,7 +7884,7 @@ public final Scope scope() {
* Creates an API for building operations in the provided execution environment
*/
public static Ops create(ExecutionEnvironment env) {
- return new Ops(new Scope(env));
+ return new Ops(env.baseScope());
}
/**
@@ -7893,6 +7893,6 @@ public static Ops create(ExecutionEnvironment env) {
*
Invoking this method is equivalent to {@code Ops.create(EagerSession.getDefault())}.
*/
public static Ops create() {
- return new Ops(new Scope(EagerSession.getDefault()));
+ return create(EagerSession.getDefault());
}
}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java
index dad842f7038..c5d67128406 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java
@@ -1,18 +1,18 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ =======================================================================
+ */
package org.tensorflow;
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_ContextOptionsSetAsync;
@@ -29,6 +29,7 @@
import org.tensorflow.internal.c_api.TFE_ContextOptions;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.op.Op;
+import org.tensorflow.op.Scope;
import org.tensorflow.op.core.Assign;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.Variable;
@@ -112,7 +113,8 @@ public Options devicePlacementPolicy(DevicePlacementPolicy value) {
* Configures the session based on the data found in the provided configuration.
*
* @param config a config protocol buffer
- * @see config.proto
+ * @see config.proto
*/
public Options config(ConfigProto config) {
this.config = config;
@@ -306,6 +308,11 @@ public void checkInput(Op input) {
}
}
+ @Override
+ public Scope baseScope() {
+ return baseScope;
+ }
+
TFE_Context nativeHandle() {
checkSession();
return nativeHandle;
@@ -314,17 +321,16 @@ TFE_Context nativeHandle() {
/**
* Attach the list of native resources to this eager session scope.
*
- *
When the eager session is closed (i.e. by calling {@link #close()} explicitly or
- * implicitly via try-with-resources), all native resources attached to the session will be
- * released as well, unless so other references are {@link Pointer#retainReference() retaining}
- * them.
+ * When the eager session is closed (i.e. by calling {@link #close()} explicitly or implicitly
+ * via try-with-resources), all native resources attached to the session will be released as well,
+ * unless so other references are {@link Pointer#retainReference() retaining} them.
*
*
Attached resources can still be garbage collected though if their associated {@link Pointer}
* is no longer reachable in Java, independently of their reference count. Therefore, it is
* assumed that these resources are not required by the native library once the Java client no
- * longer needs them.
+ * longer needs them.
*
- * Attaching a resource already attached to this session will have no effect.
+ * Attaching a resource already attached to this session will have no effect.
*
* @param resources resources to attach to the session
*/
@@ -339,14 +345,14 @@ void attach(Pointer... resources) {
* Detach a list of resources from this eager session scope.
*
*
Detached native resources will prevent them to be automatically released when the session is
- * closed.
+ * closed.
*
* Note though that this method will decrement the reference count of each resources being
- * detached, which may automatically released them if that count reaches 0. Therefore,
- * invoking {@link Pointer#retainReference()} prior to this call on any resource that must remain
- * valid after being detached might be required.
+ * detached, which may automatically released them if that count reaches 0. Therefore, invoking
+ * {@link Pointer#retainReference()} prior to this call on any resource that must remain valid
+ * after being detached might be required.
*
- * Detaching a resource that is not attached to this session will have no effect.
+ * Detaching a resource that is not attached to this session will have no effect.
*
* @param resources resources to detach from the session
*/
@@ -362,6 +368,8 @@ void detach(Pointer... resources) {
private final WeakPointerScope nativeResources;
private TFE_Context nativeHandle;
+ private final Scope baseScope = new Scope(this);
+
private EagerSession(Options options) {
this.nativeResources = new WeakPointerScope();
this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config);
@@ -381,7 +389,8 @@ private synchronized void doClose() {
}
}
- private static TFE_Context allocate(boolean async, int devicePlacementPolicy, ConfigProto config) {
+ private static TFE_Context allocate(
+ boolean async, int devicePlacementPolicy, ConfigProto config) {
try (PointerScope scope = new PointerScope()) {
TFE_ContextOptions opts = TFE_ContextOptions.newContextOptions();
TF_Status status = TF_Status.newStatus();
@@ -390,7 +399,7 @@ private static TFE_Context allocate(boolean async, int devicePlacementPolicy, Co
TFE_ContextOptionsSetConfig(opts, configBytes, configBytes.capacity(), status);
status.throwExceptionIfNotOK();
}
- TFE_ContextOptionsSetAsync(opts, (byte)(async ? 1 : 0));
+ TFE_ContextOptionsSetAsync(opts, (byte) (async ? 1 : 0));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, devicePlacementPolicy);
TFE_Context context = TFE_NewContext(opts, status);
status.throwExceptionIfNotOK();
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java
index d5389bcd0ad..a18c7fff38b 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java
@@ -1,25 +1,24 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ =======================================================================
+ */
package org.tensorflow;
import org.tensorflow.op.Op;
+import org.tensorflow.op.Scope;
-/**
- * Defines an environment for creating and executing TensorFlow {@link Operation}s.
- */
+/** Defines an environment for creating and executing TensorFlow {@link Operation}s. */
public interface ExecutionEnvironment {
enum Types {
@@ -49,11 +48,12 @@ default boolean isOpEnabled(String opType) {
}
/**
- * Checks that {@code input} is valid to use as an input in this execution environment. Throws {@link
- * IllegalArgumentException} if not.
+ * Checks that {@code input} is valid to use as an input in this execution environment. Throws
+ * {@link IllegalArgumentException} if not.
*
* @param input The op to check
- * @throws IllegalArgumentException if input can't be used as an input in this execution environment.
+ * @throws IllegalArgumentException if input can't be used as an input in this execution
+ * environment.
*/
void checkInput(Op input);
@@ -71,4 +71,10 @@ default boolean isEager() {
default boolean isGraph() {
return environmentType() == Types.GRAPH;
}
+
+ /**
+ * Get the top level scope for this execution environment. Is cached, which is necessary to
+ * prevent name collisions.
+ */
+ Scope baseScope();
}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java
index 7f659b262a6..b69fe89da0a 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java
@@ -1,18 +1,18 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ =======================================================================
+ */
package org.tensorflow;
import static org.tensorflow.internal.c_api.global.tensorflow.TF_AddGradientsWithPrefix;
@@ -52,6 +52,7 @@
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
+import org.tensorflow.op.Scope;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Identity;
import org.tensorflow.op.core.NoOp;
@@ -63,7 +64,6 @@
import org.tensorflow.types.TString;
import org.tensorflow.types.family.TType;
-
/**
* A data flow graph representing a TensorFlow computation.
*
@@ -74,18 +74,16 @@
*/
public final class Graph implements ExecutionEnvironment, AutoCloseable {
- /**
- * Create an empty Graph.
- */
+ /** Create an empty Graph. */
public Graph() {
nativeHandle = allocate();
+ this.baseScope = new Scope(this);
}
- /**
- * Create a Graph from an existing handle (takes ownership).
- */
+ /** Create a Graph from an existing handle (takes ownership). */
Graph(TF_Graph nativeHandle) {
this.nativeHandle = nativeHandle;
+ this.baseScope = new Scope(this);
}
Graph(TF_Graph nativeHandle, SaverDef saverDef) {
@@ -138,8 +136,8 @@ public GraphOperation operation(String name) {
}
/**
- * Returns the operation (node in the Graph) with the provided name, or throws {@link IllegalArgumentException} if
- * there isn't one.
+ * Returns the operation (node in the Graph) with the provided name, or throws {@link
+ * IllegalArgumentException} if there isn't one.
*
* @param name name of the operation to look for
* @return operation in the graph with this name
@@ -155,9 +153,9 @@ public GraphOperation operationOrThrow(String name) {
/**
* Returns the output with the provided name, or {@code null} if there is no such output.
- *
Names should be of the
- * format {@code /scope/op}, with an optional index: {@code /scope/op:1}. {@code 0} is used if the index is not
- * specified.
+ *
+ *
Names should be of the format {@code /scope/op}, with an optional index: {@code
+ * /scope/op:1}. {@code 0} is used if the index is not specified.
*
* @param output the output to get
* @return the output with this name, or null if there isn't one
@@ -181,15 +179,17 @@ public Output> output(String output) {
}
return new Output(operation, index);
} catch (NumberFormatException e) {
- throw new IllegalArgumentException("Could not get output for badly formatted output name: \"" + output + "\"", e);
+ throw new IllegalArgumentException(
+ "Could not get output for badly formatted output name: \"" + output + "\"", e);
}
}
/**
- * Returns the output with the provided name, or throws {@link IllegalArgumentException} if there isn't one.
- *
Names should be of the
- * format {@code /scope/op}, with an optional index: {@code /scope/op:1}. {@code 0} is used if the index is not
- * specified.
+ * Returns the output with the provided name, or throws {@link IllegalArgumentException} if there
+ * isn't one.
+ *
+ *
Names should be of the format {@code /scope/op}, with an optional index: {@code
+ * /scope/op:1}. {@code 0} is used if the index is not specified.
*
* @param output the output to get
* @return the output with this name
@@ -220,16 +220,20 @@ private GraphOperation graphOp(Operand> operand) {
}
/**
- * Finds the operations used to produce {@code outputs}, assuming {@code inputs} are provided. Includes control dependencies.
- *
- * Note that this function can easily return ops upstream of inputs as part of the body. Depending on your use, the
- * returned body should probably be filtered for {@code Placeholder}s, at least.
+ * Finds the operations used to produce {@code outputs}, assuming {@code inputs} are provided.
+ * Includes control dependencies.
+ *
+ *
Note that this function can easily return ops upstream of inputs as part of the body.
+ * Depending on your use, the returned body should probably be filtered for {@code Placeholder}s,
+ * at least.
*
- * @param inputs the inputs of the subgraph. Must be from single output ops. May not be null.
- * @param outputs the outputs of the subgraph. May not be null.
- * @return the set of operations needed to calculate outputs from inputs, including outputs and inputs
+ * @param inputs the inputs of the subgraph. Must be from single output ops. May not be null.
+ * @param outputs the outputs of the subgraph. May not be null.
+ * @return the set of operations needed to calculate outputs from inputs, including outputs and
+ * inputs
*/
- public synchronized Set completeSubgraph(Set> inputs, Set> outputs) {
+ public synchronized Set completeSubgraph(
+ Set> inputs, Set> outputs) {
if (inputs == null) {
throw new IllegalArgumentException("Inputs can't be null.");
@@ -245,7 +249,8 @@ public synchronized Set completeSubgraph(Set> inputs,
for (Operand> input : inputs) {
if (input.op().numOutputs() > 1) {
- throw new IllegalStateException("Only ops with one output are supported as subgraph inputs");
+ throw new IllegalStateException(
+ "Only ops with one output are supported as subgraph inputs");
}
GraphOperation op = graphOp(input);
inputOps.add(op);
@@ -277,15 +282,14 @@ public synchronized Set completeSubgraph(Set> inputs,
currents.add(inputOp);
}
}
-
}
return seen;
}
/**
- * Get all ops directly or indirectly required to calculate {@code outputs} (not including {@code outputs}), including
- * control dependencies.
+ * Get all ops directly or indirectly required to calculate {@code outputs} (not including {@code
+ * outputs}), including control dependencies.
*
* @param outputs the starting points of the traversal.
* @return the ops needed to calculate {@code outputs}, not including {@code outputs}
@@ -306,8 +310,8 @@ public Set subgraphToOps(Set outputs) {
}
/**
- * Get all ops that use one of {@code inputs} directly or indirectly (not including {@code inputs}), including control
- * dependencies.
+ * Get all ops that use one of {@code inputs} directly or indirectly (not including {@code
+ * inputs}), including control dependencies.
*
* @param inputs the starting points of the traversal.
* @return the ops that depend on {@code inputs}, not including {@code inputs}
@@ -328,8 +332,8 @@ public synchronized Set subgraphFromOps(Set inpu
}
/**
- * Get all ops directly or indirectly required to calculate {@code outputs} (not including {@code outputs}), including
- * control dependencies.
+ * Get all ops directly or indirectly required to calculate {@code outputs} (not including {@code
+ * outputs}), including control dependencies.
*
* @param outputs the starting points of the traversal.
* @return the ops needed to calculate {@code outputs}, not including {@code outputs}
@@ -339,8 +343,8 @@ public Set subgraphTo(Set> outputs) {
}
/**
- * Get all ops that use one of {@code inputs} directly or indirectly (not including {@code inputs}), including control
- * dependencies.
+ * Get all ops that use one of {@code inputs} directly or indirectly (not including {@code
+ * inputs}), including control dependencies.
*
* @param inputs the starting points of the traversal.
* @return the ops that depend on {@code inputs}, not including {@code inputs}
@@ -363,8 +367,8 @@ public synchronized Set subgraphFrom(Set> inputs) {
* @param type of the Operation (i.e., identifies the computation to be performed)
* @param name to refer to the created Operation in the graph.
* @return an {@link OperationBuilder}, which will add the Operation to the graph when {@link
- * OperationBuilder#build()} is invoked. If {@link OperationBuilder#build()} is not invoked, then some resources may
- * leak.
+ * OperationBuilder#build()} is invoked. If {@link OperationBuilder#build()} is not invoked,
+ * then some resources may leak.
*/
@Override
public GraphOperationBuilder opBuilder(String type, String name) {
@@ -383,18 +387,26 @@ public Types environmentType() {
public void checkInput(Op input) {
if (input.env().isEager()) {
throw new IllegalArgumentException(
- "Input " + input + " was from an eager session, can't use in a graph. Use tf.constantOf(input.asTensor())");
+ "Input "
+ + input
+ + " was from an eager session, can't use in a graph. Use tf.constantOf(input.asTensor())");
}
if (input.env() != this) {
- throw new IllegalArgumentException("Input " + input + " was from a different graph, can't use.");
+ throw new IllegalArgumentException(
+ "Input " + input + " was from a different graph, can't use.");
}
}
+ @Override
+ public Scope baseScope() {
+ return baseScope;
+ }
+
/**
* Import a representation of a TensorFlow graph.
*
- * The representation of the graph, referred to as a {@code GraphDef}, can be
- * generated by {@link #toGraphDef()} and equivalents in other language APIs.
+ *
The representation of the graph, referred to as a {@code GraphDef}, can be generated by
+ * {@link #toGraphDef()} and equivalents in other language APIs.
*
* @param graphDef {@code GraphDef} proto to import
* @throws IllegalArgumentException if graphDef is not a recognized serialization of a graph.
@@ -442,19 +454,18 @@ public synchronized void addInitializer(Op initializer) {
initializers.add(initializer);
}
- /**
- * Returns all initializers added to the graph via {@link #addInitializer(Op)}
- */
+ /** Returns all initializers added to the graph via {@link #addInitializer(Op)} */
public List initializers() {
return Collections.unmodifiableList(initializers);
}
/**
- * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e., {@code d(y_1 + y_2
- * + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
+ * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e.,
+ * {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
*
* {@code dx} are used as initial gradients (which represent the symbolic partial derivatives
- * of some loss function {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of {@code y}.
+ * of some loss function {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of
+ * {@code y}.
*
*
If {@code dx} is null, the implementation will use dx of {@link
* org.tensorflow.op.core.OnesLike OnesLike} for all shapes in {@code y}.
@@ -464,8 +475,8 @@ public List initializers() {
*
* If {@code prefix} is null, then one will be chosen automatically.
*
- * @param prefix unique string prefix applied before the names of nodes added to the graph to compute gradients. If
- * null, a default one will be chosen.
+ * @param prefix unique string prefix applied before the names of nodes added to the graph to
+ * compute gradients. If null, a default one will be chosen.
* @param y output of the function to derive
* @param x inputs of the function for which partial derivatives are computed
* @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y}
@@ -515,8 +526,11 @@ public Output>[] addGradients(String prefix, Output>[] y, Output>[] x, Out
dxIndices);
int ndy = dyHandlesAndIndices.length >> 1;
if (ndy != dy.length) {
- throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length
- + " were expected");
+ throw new IllegalStateException(
+ String.valueOf(ndy)
+ + " gradients were added to the graph when "
+ + dy.length
+ + " were expected");
}
for (int i = 0, j = ndy; i < ndy; ++i, ++j) {
GraphOperation op = new GraphOperation(this, (TF_Operation) dyHandlesAndIndices[i]);
@@ -527,23 +541,24 @@ public Output>[] addGradients(String prefix, Output>[] y, Output>[] x, Out
}
/**
- * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e., {@code dy/dx_1,
- * dy/dx_2...}
- *
- * This is a simplified version of {@link #addGradients(String, Output[], Output[], Output[])} where {@code y} is a
- * single output, {@code dx} is null and {@code prefix} is null.
+ * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e.,
+ * {@code dy/dx_1, dy/dx_2...}
+ *
+ *
This is a simplified version of {@link #addGradients(String, Output[], Output[], Output[])}
+ * where {@code y} is a single output, {@code dx} is null and {@code prefix} is null.
*
* @param y output of the function to derive
* @param x inputs of the function for which partial derivatives are computed
* @return the partial derivatives {@code dy} with the size of {@code x}
*/
public Output>[] addGradients(Output> y, Output>[] x) {
- return addGradients(null, new Output>[]{y}, x, null);
+ return addGradients(null, new Output>[] {y}, x, null);
}
/**
- * Used to instantiate an abstract class which overrides the buildSubgraph method to build a conditional or body
- * subgraph for a while loop. After Java 8, this can alternatively be used to create a lambda for the same purpose.
+ * Used to instantiate an abstract class which overrides the buildSubgraph method to build a
+ * conditional or body subgraph for a while loop. After Java 8, this can alternatively be used to
+ * create a lambda for the same purpose.
*
*
To be used when calling {@link #whileLoop(Output[],
* org.tensorflow.Graph.WhileSubgraphBuilder, org.tensorflow.Graph.WhileSubgraphBuilder, String)}
@@ -558,7 +573,9 @@ public Output>[] addGradients(Output> y, Output>[] x) {
* }
* };
*
+ *
* Example usage (after Java 8):
+ *
*
* WhileSubgraphBuilder bodyGraphBuilder = (bodyGraph, bodyInputs, bodyOutputs) -> { //
* build body subgraph
@@ -657,13 +674,15 @@ public Output>[] whileLoop(
}
/**
- * Return the {@link SaverDef} instance used to save the state of all variables present in this graph.
+ * Return the {@link SaverDef} instance used to save the state of all variables present in this
+ * graph.
*
- * The first time this method is called it builds the {@link SaverDef}. If this graph already contains a
- * "save/restore_all" operation then it is assumed to contain all necessary saving and restoring operations. If that
- * operation does not exist then the graph is mutated to add all the nodes necessary to save and restore the state of
- * the graph. Consequently, any variables that are added to the graph after this call will not be saved nor restored
- * using this {@link SaverDef}.
+ * The first time this method is called it builds the {@link SaverDef}. If this graph already
+ * contains a "save/restore_all" operation then it is assumed to contain all necessary saving and
+ * restoring operations. If that operation does not exist then the graph is mutated to add all the
+ * nodes necessary to save and restore the state of the graph. Consequently, any variables that
+ * are added to the graph after this call will not be saved nor restored using this {@link
+ * SaverDef}.
*
* @return a {@link SaverDef} instance
*/
@@ -678,11 +697,12 @@ synchronized SaverDef saverDef() {
// regenerate SaverDef without mutating. The names mirror
// the python implementation for compatibility.
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py
- saverDef = SaverDef.newBuilder()
- .setFilenameTensorName("save/filename")
- .setSaveTensorName("save/control_dependency")
- .setRestoreOpName("save/restore_all")
- .build();
+ saverDef =
+ SaverDef.newBuilder()
+ .setFilenameTensorName("save/filename")
+ .setSaveTensorName("save/control_dependency")
+ .setRestoreOpName("save/restore_all")
+ .build();
}
}
return saverDef;
@@ -692,6 +712,7 @@ synchronized SaverDef saverDef() {
private TF_Graph nativeHandle;
private int refcount = 0;
private SaverDef saverDef;
+ private final Scope baseScope;
private final List initializers = new ArrayList<>();
@@ -757,7 +778,9 @@ private final void advance() {
try {
Object[] nativeReturn = nextOperation(reference.nativeHandle(), this.position);
- if (nativeReturn != null && nativeReturn[0] != null && !((TF_Operation) nativeReturn[0]).isNull()) {
+ if (nativeReturn != null
+ && nativeReturn[0] != null
+ && !((TF_Operation) nativeReturn[0]).isNull()) {
this.operation = new GraphOperation(this.graph, (TF_Operation) nativeReturn[0]);
this.position = (Integer) nativeReturn[1];
}
@@ -863,14 +886,21 @@ private static GraphDef toGraphDef(TF_Graph handle) {
}
}
- static void resolveOutputs(String type, TF_Operation[] srcOps,
- int[] srcIndices, TF_Output dst, int n) {
+ static void resolveOutputs(
+ String type, TF_Operation[] srcOps, int[] srcIndices, TF_Output dst, int n) {
if (srcOps.length != n) {
- throw new IllegalArgumentException("expected " + n + ", got " + srcOps.length + " " + type + " Operations");
+ throw new IllegalArgumentException(
+ "expected " + n + ", got " + srcOps.length + " " + type + " Operations");
}
if (srcIndices.length != n) {
throw new IllegalArgumentException(
- "expected " + n + ", got " + srcIndices.length + " " + type + " Operation output indices");
+ "expected "
+ + n
+ + ", got "
+ + srcIndices.length
+ + " "
+ + type
+ + " Operation output indices");
}
for (int i = 0; i < n; ++i) {
if (srcOps[i] == null || srcOps[i].isNull()) {
@@ -905,7 +935,8 @@ private static Object[] addGradients(
resolveOutputs("x", outputHandles, outputIndices, x, nx);
if (gradInputHandles != null) {
if (gradInputHandles.length != ny) {
- throw new IllegalArgumentException("expected " + ny + ", got " + gradInputHandles.length + " handles");
+ throw new IllegalArgumentException(
+ "expected " + ny + ", got " + gradInputHandles.length + " handles");
}
dx = new TF_Output(ny);
resolveOutputs("dx", gradInputHandles, gradInputIndices, dx, ny);
@@ -961,9 +992,13 @@ private static Object[] whileLoop(
condOutputIndices[0] = condOutputOutput.index();
Object[] condOutputHandlesAndIndices =
- buildSubgraph(condGraphBuilder, params.cond_graph(),
- condInputHandles, condInputIndices,
- condOutputHandles, condOutputIndices);
+ buildSubgraph(
+ condGraphBuilder,
+ params.cond_graph(),
+ condInputHandles,
+ condInputIndices,
+ condOutputHandles,
+ condOutputIndices);
// build body subgraph
TF_Output bodyInputsOutput = params.body_inputs();
@@ -980,22 +1015,28 @@ private static Object[] whileLoop(
}
Object[] bodyOutputHandlesAndIndices =
- buildSubgraph(bodyGraphBuilder, params.body_graph(),
- bodyInputHandles, bodyInputIndices,
- bodyOutputHandles, bodyOutputIndices);
-
- if (condOutputHandlesAndIndices == null ||
- bodyOutputHandlesAndIndices == null) {
+ buildSubgraph(
+ bodyGraphBuilder,
+ params.body_graph(),
+ bodyInputHandles,
+ bodyInputIndices,
+ bodyOutputHandles,
+ bodyOutputIndices);
+
+ if (condOutputHandlesAndIndices == null || bodyOutputHandlesAndIndices == null) {
return null;
}
// set cond_output param to output of the conditional subgraph
- condOutputOutput.oper((TF_Operation) condOutputHandlesAndIndices[0])
+ condOutputOutput
+ .oper((TF_Operation) condOutputHandlesAndIndices[0])
.index((Integer) condOutputHandlesAndIndices[1]);
// set body_outputs param to outputs of the body subgraph
for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) {
- bodyOutputsOutput.position(i).oper((TF_Operation) bodyOutputHandlesAndIndices[i])
+ bodyOutputsOutput
+ .position(i)
+ .oper((TF_Operation) bodyOutputHandlesAndIndices[i])
.index((Integer) bodyOutputHandlesAndIndices[j]);
}
@@ -1042,20 +1083,12 @@ private static SaverDef addVariableSaver(Graph graph) {
Operand varSlices = tf.zerosLike(varNamesTensor);
Placeholder saveFilename = tf.withName("filename").placeholder(TString.class);
- Save saveVariables = tf.train.save(
- saveFilename,
- varNamesTensor,
- varSlices,
- varOutputs
- );
- Identity id = tf.withControlDependencies(Arrays.asList(saveFilename, saveVariables))
- .withName("control_dependency").identity(saveFilename);
- Restore restoreVariables = tf.train.restore(
- saveFilename,
- varNamesTensor,
- varSlices,
- varTypes
- );
+ Save saveVariables = tf.train.save(saveFilename, varNamesTensor, varSlices, varOutputs);
+ Identity id =
+ tf.withControlDependencies(Arrays.asList(saveFilename, saveVariables))
+ .withName("control_dependency")
+ .identity(saveFilename);
+ Restore restoreVariables = tf.train.restore(saveFilename, varNamesTensor, varSlices, varTypes);
List restoreOps = new ArrayList<>(varOutputs.size());
for (int i = 0; i < varOutputs.size(); ++i) {
restoreOps.add(tf.assign(varOutputs.get(i), (Operand) restoreVariables.tensors().get(i)));
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java
index 2e84cac1ac7..903a12f66b2 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java
@@ -1,23 +1,26 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ =======================================================================
+ */
package org.tensorflow.op;
import java.util.HashMap;
import java.util.Map;
+import java.util.regex.Matcher;
import java.util.regex.Pattern;
+import org.tensorflow.ExecutionEnvironment;
+import org.tensorflow.Graph;
/**
* A class to manage scoped (hierarchical) names for operators.
@@ -36,12 +39,12 @@
*/
final class NameScope {
- NameScope withSubScope(String scopeName) {
+ NameScope withSubScope(String scopeName, ExecutionEnvironment env) {
checkPattern(NAME_REGEX, scopeName);
// Override with opName if it exists.
String actualName = (opName != null) ? opName : scopeName;
String newPrefix = fullyQualify(makeUnique(actualName));
- return new NameScope(newPrefix, null, null);
+ return new NameScope(newPrefix, null, null).withUsedFrom(env);
}
NameScope withName(String name) {
@@ -50,6 +53,46 @@ NameScope withName(String name) {
return new NameScope(opPrefix, name, ids);
}
+ private static final Pattern NAME_PATTERN = Pattern.compile("(.+)_(\\d+)", Pattern.DOTALL);
+
+ /** "Import" used names from a graph. Useful when adding to a loaded graph. */
+ private NameScope withUsedFrom(ExecutionEnvironment env) {
+
+ if (env instanceof Graph) {
+ ((Graph) env)
+ .operations()
+ .forEachRemaining(
+ op -> {
+ if (op.name().startsWith(opPrefix != null ? opPrefix : "")) {
+ String name = op.name();
+
+ if (opPrefix != null) {
+ name = name.substring(opPrefix.length() + 1);
+ }
+
+ if (!name.contains("/")) {
+ Matcher matcher = NAME_PATTERN.matcher(name);
+ if (matcher.find()) {
+ String realName = matcher.group(1);
+ int num = Integer.parseInt(matcher.group(2)) + 1;
+
+ if (!(ids.containsKey(realName) && ids.get(realName) > num)) {
+ ids.put(realName, num);
+ }
+ } else {
+ if (!ids.containsKey(name)) {
+ ids.put(name, 1);
+ } else {
+ ids.put(name, ids.get(name) + 1);
+ }
+ }
+ }
+ }
+ });
+ }
+ return this;
+ }
+
String makeOpName(String name) {
checkPattern(NAME_REGEX, name);
// Override with opName if it exists.
@@ -62,9 +105,12 @@ String makeOpName(String name) {
*
* A root-level namescope generates operator names with no components, like {@code Const_72}
* and {@code result}.
+ *
+ * @param env
*/
- NameScope() {
+ NameScope(ExecutionEnvironment env) {
this(null, null, null);
+ withUsedFrom(env);
}
private NameScope(String opPrefix, String opName, Map ids) {
@@ -120,6 +166,13 @@ private String fullyQualify(String name) {
// instance mapped to the next available numeric suffix for it.
private final Map ids;
+ static boolean isValidName(String name) {
+ if (name == null) {
+ return false;
+ }
+ return NAME_REGEX.matcher(name).matches();
+ }
+
private static void checkPattern(Pattern pattern, String name) {
if (name == null) {
throw new IllegalArgumentException("Names cannot be null");
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java
index 85e283d9260..2aef70f6af0 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java
@@ -1,18 +1,18 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ =======================================================================
+ */
package org.tensorflow.op;
import java.util.ArrayList;
@@ -21,7 +21,8 @@
import org.tensorflow.OperationBuilder;
/**
- * Manages groups of related properties when creating Tensorflow Operations, such as a common name prefix.
+ * Manages groups of related properties when creating Tensorflow Operations, such as a common name
+ * prefix.
*
* A {@code Scope} is a container for common properties applied to TensorFlow Ops. Normal user
* code initializes a {@code Scope} and provides it to Operation building classes. For example:
@@ -80,15 +81,16 @@ public final class Scope {
/**
* Create a new top-level scope.
*
+ *
For internal use only, use {@link ExecutionEnvironment#baseScope()} if you need a
+ * base level scope.
+ *
* @param env The execution environment used by the scope.
*/
public Scope(ExecutionEnvironment env) {
- this(env, new NameScope(), new ArrayList<>(), DeviceSpec.newBuilder().build());
+ this(env, new NameScope(env), new ArrayList<>(), DeviceSpec.newBuilder().build());
}
- /**
- * Returns the execution environment used by this scope.
- */
+ /** Returns the execution environment used by this scope. */
public ExecutionEnvironment env() {
return env;
}
@@ -97,7 +99,8 @@ public ExecutionEnvironment env() {
* Returns a new scope where added operations will have the provided name prefix.
*
*
Ops created with this scope will have {@code name/childScopeName/} as the prefix. The actual
- * name will be unique in the returned scope. All other properties are inherited from the current scope.
+ * name will be unique in the returned scope. All other properties are inherited from the current
+ * scope.
*
*
The child scope name must match the regular expression {@code [A-Za-z0-9.][A-Za-z0-9_.\-]*}
*
@@ -106,7 +109,8 @@ public ExecutionEnvironment env() {
* @throws IllegalArgumentException if the name is invalid
*/
public Scope withSubScope(String childScopeName) {
- return new Scope(env, nameScope.withSubScope(childScopeName), controlDependencies, deviceSpec);
+ return new Scope(
+ env, nameScope.withSubScope(childScopeName, env), controlDependencies, deviceSpec);
}
/**
@@ -126,29 +130,34 @@ public Scope withName(String opName) {
}
/**
- * Returns a new scope where added operations will be prefixed by this scope's op name
- * (set by {@link #withName(String)}), or the given default if it is unset. This is intended to be used for
- * composite ops.
+ * Returns a new scope where added operations will be prefixed by this scope's op name (set by
+ * {@link #withName(String)}), or the given default if it is unset. This is intended to be used
+ * for composite ops.
*
- *
Ops created with this scope will have {@code name/opName/} as the prefix. The actual
- * name will be unique in the returned scope. All other properties are inherited from the current
+ *
Ops created with this scope will have {@code name/opName/} as the prefix. The actual name
+ * will be unique in the returned scope. All other properties are inherited from the current
* scope.
*
- *
The default child scope name must match the regular expression {@code [A-Za-z0-9.][A-Za-z0-9_.\-]*}
+ *
The default child scope name must match the regular expression {@code
+ * [A-Za-z0-9.][A-Za-z0-9_.\-]*}
*
* @param defaultName name of the sub scope if this scope's name hasn't been set.
* @return a new subscope
* @throws IllegalArgumentException if the name is invalid
*/
- public Scope withNameAsSubScope(String defaultName){
- return new Scope(env, nameScope.withSubScope(nameScope.makeOpName(defaultName)), controlDependencies, deviceSpec);
+ public Scope withNameAsSubScope(String defaultName) {
+ return new Scope(
+ env,
+ nameScope.withSubScope(nameScope.makeOpName(defaultName), env),
+ controlDependencies,
+ deviceSpec);
}
/**
* Return a new scope that uses the provided device specification for an op.
*
- *
Operations created within this scope will place the created operations on the device(s) matching the provided
- * spec.
+ *
Operations created within this scope will place the created operations on the device(s)
+ * matching the provided spec.
*
* @param deviceSpec device specification for an operator in the returned scope
* @return a new Scope that uses opName for operations.
@@ -170,8 +179,8 @@ public Scope withDevice(DeviceSpec deviceSpec) {
* }
*
* Note: if you provide a composite operator building class (i.e, a class that creates a
- * set of related operations by calling other operator building code), the provided name will act as a subscope to all
- * underlying operators.
+ * set of related operations by calling other operator building code), the provided name will act
+ * as a subscope to all underlying operators.
*
* @param defaultName name for the underlying operator.
* @return unique name for the operator.
@@ -181,8 +190,15 @@ public String makeOpName(String defaultName) {
return nameScope.makeOpName(defaultName);
}
+ public static boolean isValidOpName(String name) {
+ return NameScope.isValidName(name);
+ }
+
private Scope(
- ExecutionEnvironment env, NameScope nameScope, Iterable controlDependencies, DeviceSpec deviceSpec) {
+ ExecutionEnvironment env,
+ NameScope nameScope,
+ Iterable controlDependencies,
+ DeviceSpec deviceSpec) {
this.env = env;
this.nameScope = nameScope;
this.controlDependencies = controlDependencies;
@@ -206,8 +222,8 @@ public Scope withControlDependencies(Iterable controls) {
}
/**
- * Applies device specification and adds each Operand in controlDependencies as a control input to the provided
- * builder.
+ * Applies device specification and adds each Operand in controlDependencies as a control input to
+ * the provided builder.
*
* @param builder OperationBuilder to add control inputs and device specification to
*/
@@ -233,9 +249,7 @@ public OperationBuilder applyControlDependencies(OperationBuilder builder) {
private final NameScope nameScope;
private final DeviceSpec deviceSpec;
- /**
- * Returns device string from the scope.
- */
+ /** Returns device string from the scope. */
public String getDeviceString() {
return deviceSpec.toString();
}
diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java
index 62881dcee8c..84eabd3da1a 100644
--- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java
@@ -1,18 +1,18 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2017-2021 The TensorFlow Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ =======================================================================
+ */
package org.tensorflow.op;
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -29,6 +29,24 @@
/** Unit tests for {@link org.tensorflow.op.Scope}. */
public class ScopeTest {
+ @Test
+ public void testSeparateOps() {
+ try (Graph g = new Graph()) {
+ Ops tf1 = Ops.create(g);
+ Ops tf2 = Ops.create(g);
+
+ tf1.constant(2);
+ tf1.withName("Constant2").constant(2);
+ tf1.withSubScope("Scope").constant(2);
+ tf1.withSubScope("Scope").withName("Constant4").constant(2);
+
+ tf2.constant(2);
+ tf2.withName("Constant2").constant(2);
+ tf2.withSubScope("Scope").constant(2);
+ tf2.withSubScope("Scope").withName("Constant4").constant(2);
+ }
+ }
+
@Test
public void basicNames() {
try (Graph g = new Graph()) {
@@ -168,9 +186,9 @@ public void composite() {
// assertNotNull(g.operation("variance/zero"));
// Verify correct results as well.
- TInt32 result = (TInt32)sess.runner().fetch(var1.output()).run().get(0);
+ TInt32 result = (TInt32) sess.runner().fetch(var1.output()).run().get(0);
assertEquals(21704, result.getInt());
- result = (TInt32)sess.runner().fetch(var2.output()).run().get(0);
+ result = (TInt32) sess.runner().fetch(var2.output()).run().get(0);
assertEquals(21704, result.getInt());
}
}
diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java
index bea817e9011..1b1d5cb0fb3 100644
--- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java
+++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java
@@ -1,19 +1,18 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ =======================================================================
+ */
package org.tensorflow.processor.operator;
import com.github.javaparser.ast.comments.JavadocComment;
@@ -156,26 +155,28 @@ public Set getSupportedAnnotationTypes() {
}
private static class OpsSpec {
- private static final Comparator PARAMETER_SPEC_COMPARATOR = (o1, o2) -> {
- if (o1.parameters.size() > o2.parameters.size()) {
- return 1;
- }
- if (o1.parameters.size() < o2.parameters.size()) {
- return -1;
- }
- List firstParams = o1.parameters;
- List secondParams = o2.parameters;
- for (int i = 0; i < firstParams.size(); i++) {
- ParameterSpec first = firstParams.get(i);
- ParameterSpec second = secondParams.get(i);
- int compare = first.name.compareTo(second.name);
- if (compare != 0) {
- return compare;
- }
- }
- return 0;
- };
- private static final Comparator METHOD_SPEC_COMPARATOR = Comparator.comparing((MethodSpec m) -> m.name).thenComparing(PARAMETER_SPEC_COMPARATOR);
+ private static final Comparator PARAMETER_SPEC_COMPARATOR =
+ (o1, o2) -> {
+ if (o1.parameters.size() > o2.parameters.size()) {
+ return 1;
+ }
+ if (o1.parameters.size() < o2.parameters.size()) {
+ return -1;
+ }
+ List firstParams = o1.parameters;
+ List secondParams = o2.parameters;
+ for (int i = 0; i < firstParams.size(); i++) {
+ ParameterSpec first = firstParams.get(i);
+ ParameterSpec second = secondParams.get(i);
+ int compare = first.name.compareTo(second.name);
+ if (compare != 0) {
+ return compare;
+ }
+ }
+ return 0;
+ };
+ private static final Comparator METHOD_SPEC_COMPARATOR =
+ Comparator.comparing((MethodSpec m) -> m.name).thenComparing(PARAMETER_SPEC_COMPARATOR);
final String groupName;
final String fieldName;
@@ -183,7 +184,8 @@ private static class OpsSpec {
final List methods;
final List subGroups = new ArrayList<>();
- OpsSpec(String groupName, String fieldName, ClassName className, Collection methods) {
+ OpsSpec(
+ String groupName, String fieldName, ClassName className, Collection methods) {
this.groupName = groupName;
this.fieldName = fieldName;
this.className = className;
@@ -227,11 +229,11 @@ private void error(Element e, String message, Object... args) {
private void write(TypeSpec spec) {
try {
JavaFile.builder("org.tensorflow.op", spec)
- .addFileComment(LICENSE)
- .addFileComment("\nThis class has been generated, DO NOT EDIT!\n")
- .skipJavaLangImports(true)
- .build()
- .writeTo(filer);
+ .addFileComment(LICENSE)
+ .addFileComment("\nThis class has been generated, DO NOT EDIT!\n")
+ .skipJavaLangImports(true)
+ .build()
+ .writeTo(filer);
} catch (IOException e) {
throw new AssertionError(e);
}
@@ -262,7 +264,7 @@ private boolean collectOpsMethods(
result = false;
continue;
}
- collectOpMethods(groupedMethods, (TypeElement)e, annotation);
+ collectOpMethods(groupedMethods, (TypeElement) e, annotation);
}
return result;
}
@@ -281,7 +283,8 @@ private void collectOpMethods(
String opGroup = getAnnotationElementValueAsString("group", operatorAnnot);
String opName = getAnnotationElementValueAsString("name", operatorAnnot);
if (Strings.isNullOrEmpty(opName)) {
- opName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, ClassName.get(opClass).simpleName());
+ opName =
+ CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, ClassName.get(opClass).simpleName());
}
// Build an endpoint for each method annotated with @Endpoint, which takes in parameter a scope
// and, optionally, a list of arguments
@@ -293,11 +296,17 @@ private void collectOpMethods(
throw new IllegalArgumentException(
"Endpoint " + opMethod + " of class " + opClass + " must be static and public");
}
- if (opMethod.getParameters().isEmpty() ||
- !((TypeElement)types.asElement(opMethod.getParameters().get(0).asType())).getQualifiedName()
+ if (opMethod.getParameters().isEmpty()
+ || !((TypeElement) types.asElement(opMethod.getParameters().get(0).asType()))
+ .getQualifiedName()
.equals(elements.getName(Names.Scope.toString()))) {
throw new IllegalArgumentException(
- "Endpoint " + opMethod + " of class " + opClass + " must take an instance of " + Names.Scope
+ "Endpoint "
+ + opMethod
+ + " of class "
+ + opClass
+ + " must take an instance of "
+ + Names.Scope
+ " as its first parameter");
}
String endpointGroup = getAnnotationElementValueAsString("group", endpointAnnot);
@@ -311,15 +320,19 @@ private void collectOpMethods(
boolean describeByClass =
getAnnotationElementValueAsBoolean("describeByClass", endpointAnnot, false);
boolean deprecated = opMethod.getAnnotation(Deprecated.class) != null || opClassDeprecated;
- MethodSpec method = buildOpMethod(endpointName, opClass, opMethod, describeByClass, deprecated);
+ MethodSpec method =
+ buildOpMethod(endpointName, opClass, opMethod, describeByClass, deprecated);
groupedMethods.put(endpointGroup, method);
}
}
}
private MethodSpec buildOpMethod(
- String methodName, TypeElement opClass, ExecutableElement endpointMethod,
- boolean describeByClass, boolean deprecated) {
+ String methodName,
+ TypeElement opClass,
+ ExecutableElement endpointMethod,
+ boolean describeByClass,
+ boolean deprecated) {
MethodSpec.Builder builder =
MethodSpec.methodBuilder(methodName)
.addModifiers(Modifier.PUBLIC)
@@ -341,9 +354,7 @@ private MethodSpec buildOpMethod(
if (!NoType.class.isAssignableFrom(endpointMethod.getReturnType().getClass())) {
call.append("return ");
}
- call.append("$T.")
- .append(endpointMethod.getSimpleName())
- .append("(scope");
+ call.append("$T.").append(endpointMethod.getSimpleName()).append("(scope");
boolean first = true;
for (VariableElement param : endpointMethod.getParameters()) {
ParameterSpec p = ParameterSpec.get(param);
@@ -374,50 +385,68 @@ private String buildOpMethodJavadoc(
// Copy all endpoint method tags to the description, except for the `scope` parameter which
// will be inferred by the Ops class
- methodJavadoc.getBlockTags().forEach(t -> {
- if (!(t.getTagName().equals("param") && t.getName().map(s -> s.equals("scope")).orElse(false))) {
- javadoc.addBlockTag(t);
- }
- });
+ methodJavadoc
+ .getBlockTags()
+ .forEach(
+ t -> {
+ if (!(t.getTagName().equals("param")
+ && t.getName().map(s -> s.equals("scope")).orElse(false))) {
+ javadoc.addBlockTag(t);
+ }
+ });
return javadoc.toText();
}
- private static Collection collectGroupOps(OpsSpec ops, Multimap groupedMethods) {
+ private static Collection collectGroupOps(
+ OpsSpec ops, Multimap groupedMethods) {
Map groups = new HashMap<>();
- // The `group` label added in the `@Operator` annotation has the same syntax as a package name, which (in most
- // case) consists of a simple label but could also be a deeper tree, like `linalg.sparse`. In this case,
- // the `LinalgSparseOps` group should be added as the `sparse` field of the `LinalgOps` group, and the latter
+ // The `group` label added in the `@Operator` annotation has the same syntax as a package name,
+ // which (in most
+ // case) consists of a simple label but could also be a deeper tree, like `linalg.sparse`. In
+ // this case,
+ // the `LinalgSparseOps` group should be added as the `sparse` field of the `LinalgOps` group,
+ // and the latter
// should be added as the `linalg` field of the `Ops` root class.
- groupedMethods.keys().forEach(group -> {
- OpsSpec parentClass = ops;
- int startPos = 0;
- do {
- int delimiterPos = group.indexOf('.', startPos);
- String groupName = delimiterPos < 0 ? group : group.substring(0, delimiterPos);
- OpsSpec groupOps = groups.get(groupName);
-
- // Create spec for this group if we have not encountered it yet in our iteration
- if (groupOps == null) {
- String fieldName = delimiterPos < 0 ?
- group.substring(startPos) : group.substring(startPos, delimiterPos);
- ClassName className = ClassName.get("org.tensorflow.op",
- CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.UPPER_CAMEL, groupName.replace('.', '_')) + "Ops");
- groupOps = new OpsSpec(groupName, fieldName, className, groupedMethods.get(groupName));
- parentClass.subGroups.add(groupOps);
- groups.put(groupName, groupOps);
- }
- parentClass = groupOps;
- startPos = delimiterPos + 1;
- } while (startPos > 0);
- });
+ groupedMethods
+ .keys()
+ .forEach(
+ group -> {
+ OpsSpec parentClass = ops;
+ int startPos = 0;
+ do {
+ int delimiterPos = group.indexOf('.', startPos);
+ String groupName = delimiterPos < 0 ? group : group.substring(0, delimiterPos);
+ OpsSpec groupOps = groups.get(groupName);
+
+ // Create spec for this group if we have not encountered it yet in our iteration
+ if (groupOps == null) {
+ String fieldName =
+ delimiterPos < 0
+ ? group.substring(startPos)
+ : group.substring(startPos, delimiterPos);
+ ClassName className =
+ ClassName.get(
+ "org.tensorflow.op",
+ CaseFormat.LOWER_UNDERSCORE.to(
+ CaseFormat.UPPER_CAMEL, groupName.replace('.', '_'))
+ + "Ops");
+ groupOps =
+ new OpsSpec(groupName, fieldName, className, groupedMethods.get(groupName));
+ parentClass.subGroups.add(groupOps);
+ groups.put(groupName, groupOps);
+ }
+ parentClass = groupOps;
+ startPos = delimiterPos + 1;
+ } while (startPos > 0);
+ });
return groups.values();
}
private static TypeSpec buildGroupClass(OpsSpec spec) {
- //System.out.println("Generating " + spec.className + " class");
+ // System.out.println("Generating " + spec.className + " class");
MethodSpec.Builder ctorBuilder =
MethodSpec.constructorBuilder()
@@ -436,7 +465,8 @@ private static TypeSpec buildGroupClass(OpsSpec spec) {
Names.Ops)
.addMethods(spec.methods);
- MethodSpec.Builder opsBuilder = MethodSpec.methodBuilder("ops")
+ MethodSpec.Builder opsBuilder =
+ MethodSpec.methodBuilder("ops")
.addModifiers(Modifier.PUBLIC, Modifier.FINAL)
.returns(Names.Ops)
.addJavadoc("Get the parent {@link " + Names.Ops.simpleName() + "} object.")
@@ -449,21 +479,23 @@ private static TypeSpec buildGroupClass(OpsSpec spec) {
builder.addMethod(ctorBuilder.build());
builder.addField(
- FieldSpec.builder(Names.Scope, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build());
+ FieldSpec.builder(Names.Scope, "scope")
+ .addModifiers(Modifier.PRIVATE, Modifier.FINAL)
+ .build());
builder.addField(
- FieldSpec.builder(Names.Ops, "ops").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build());
+ FieldSpec.builder(Names.Ops, "ops").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build());
return builder.build();
}
private static TypeSpec buildTopClass(OpsSpec spec) {
- //System.out.println("Generating " + spec.className + " class");
+ // System.out.println("Generating " + spec.className + " class");
MethodSpec.Builder ctorBuilder =
MethodSpec.constructorBuilder()
- .addModifiers(Modifier.PRIVATE)
.addParameter(Names.Scope, "scope")
+ .addModifiers(Modifier.PRIVATE)
.addStatement("this.scope = scope", Names.Scope);
TypeSpec.Builder opsBuilder =
@@ -531,16 +563,16 @@ private static TypeSpec buildTopClass(OpsSpec spec) {
.build());
opsBuilder.addMethod(
- MethodSpec.methodBuilder("withDevice")
- .addModifiers(Modifier.PUBLIC)
- .addParameter(Names.DeviceSpec, "deviceSpec")
- .returns(Names.Ops)
- .addStatement("return new Ops(scope.withDevice(deviceSpec))")
- .addJavadoc(
- "Returns an API that places the created operations on the device(s) matching the provided spec.\n\n"
- + "@see {@link $T#withDevice(DeviceSpec)}\n",
- Names.Scope)
- .build());
+ MethodSpec.methodBuilder("withDevice")
+ .addModifiers(Modifier.PUBLIC)
+ .addParameter(Names.DeviceSpec, "deviceSpec")
+ .returns(Names.Ops)
+ .addStatement("return new Ops(scope.withDevice(deviceSpec))")
+ .addJavadoc(
+ "Returns an API that places the created operations on the device(s) matching the provided spec.\n\n"
+ + "@see {@link $T#withDevice(DeviceSpec)}\n",
+ Names.Scope)
+ .build());
opsBuilder.addMethod(
MethodSpec.methodBuilder("withControlDependencies")
@@ -555,7 +587,9 @@ private static TypeSpec buildTopClass(OpsSpec spec) {
.build());
opsBuilder.addField(
- FieldSpec.builder(Names.Scope, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build());
+ FieldSpec.builder(Names.Scope, "scope")
+ .addModifiers(Modifier.PRIVATE, Modifier.FINAL)
+ .build());
opsBuilder.addMethod(
MethodSpec.methodBuilder("scope")
@@ -570,7 +604,7 @@ private static TypeSpec buildTopClass(OpsSpec spec) {
.addModifiers(Modifier.PUBLIC, Modifier.STATIC)
.addParameter(Names.ExecutionEnvironment, "env")
.returns(Names.Ops)
- .addStatement("return new Ops(new $T(env))", Names.Scope)
+ .addStatement("return new Ops(env.baseScope())", Names.Scope)
.addJavadoc(
"Creates an API for building operations in the provided execution environment\n")
.build());
@@ -579,7 +613,7 @@ private static TypeSpec buildTopClass(OpsSpec spec) {
MethodSpec.methodBuilder("create")
.addModifiers(Modifier.PUBLIC, Modifier.STATIC)
.returns(Names.Ops)
- .addStatement("return new Ops(new $T($T.getDefault()))", Names.Scope, Names.EagerSession)
+ .addStatement("return create($T.getDefault())", Names.EagerSession)
.addJavadoc(
"Creates an API for building operations in the default eager execution environment\n\n"
+ "Invoking this method is equivalent to {@code Ops.create(EagerSession.getDefault())}.\n")
@@ -588,27 +622,39 @@ private static TypeSpec buildTopClass(OpsSpec spec) {
return opsBuilder.build();
}
- private static void addGroupFields(TypeSpec.Builder classBuilder, MethodSpec.Builder ctorBuilder, List groups, boolean isTopClass) {
- groups.forEach(group -> {
- classBuilder.addField(
- FieldSpec.builder(group.className, group.fieldName)
- .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
- .build()
- );
- ctorBuilder.addStatement("$L = new $T(" + (isTopClass ? "this" : "ops") + ")", group.fieldName, group.className).build();
- });
+ private static void addGroupFields(
+ TypeSpec.Builder classBuilder,
+ MethodSpec.Builder ctorBuilder,
+ List groups,
+ boolean isTopClass) {
+ groups.forEach(
+ group -> {
+ classBuilder.addField(
+ FieldSpec.builder(group.className, group.fieldName)
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .build());
+ ctorBuilder
+ .addStatement(
+ "$L = new $T(" + (isTopClass ? "this" : "ops") + ")",
+ group.fieldName,
+ group.className)
+ .build();
+ });
}
private static AnnotationMirror getAnnotationMirror(Element element, Name annotationName) {
for (AnnotationMirror am : element.getAnnotationMirrors()) {
- if (((TypeElement)am.getAnnotationType().asElement()).getQualifiedName().equals(annotationName)) {
+ if (((TypeElement) am.getAnnotationType().asElement())
+ .getQualifiedName()
+ .equals(annotationName)) {
return am;
}
}
return null;
}
- private static AnnotationValue getAnnotationElementValue(String elementName, AnnotationMirror am) {
+ private static AnnotationValue getAnnotationElementValue(
+ String elementName, AnnotationMirror am) {
for (Map.Entry extends ExecutableElement, ? extends AnnotationValue> entry :
am.getElementValues().entrySet()) {
if (entry.getKey().getSimpleName().contentEquals(elementName)) {
@@ -623,7 +669,8 @@ private static String getAnnotationElementValueAsString(String elementName, Anno
return value != null ? value.getValue().toString() : "";
}
- private static boolean getAnnotationElementValueAsBoolean(String elementName, AnnotationMirror am, boolean defaultValue) {
+ private static boolean getAnnotationElementValueAsBoolean(
+ String elementName, AnnotationMirror am, boolean defaultValue) {
AnnotationValue value = getAnnotationElementValue(elementName, am);
return value != null ? Boolean.parseBoolean(value.toString()) : defaultValue;
}