diff --git a/license-header b/license-header deleted file mode 100644 index 12bf9309e9a..00000000000 --- a/license-header +++ /dev/null @@ -1,14 +0,0 @@ -/* - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ \ No newline at end of file diff --git a/pom.xml b/pom.xml index 76504524ae9..7288f3661b3 100644 --- a/pom.xml +++ b/pom.xml @@ -342,7 +342,23 @@ - ./license-header + +/* Copyright $YEAR The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ + diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 4044838de87..92e4cabdbd1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -354,10 +354,10 @@ public final class Ops { public final SparseOps sparse; - public final TpuOps tpu; - public final BitwiseOps bitwise; + public final TpuOps tpu; + public final MathOps math; public final AudioOps audio; @@ -385,8 +385,8 @@ private Ops(Scope scope) { random = new RandomOps(this); strings = new StringsOps(this); sparse = new SparseOps(this); - tpu = new TpuOps(this); bitwise = new BitwiseOps(this); + tpu = new TpuOps(this); math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); @@ -7884,7 +7884,7 @@ public final Scope scope() { * Creates an API for building operations in the provided execution environment */ public static Ops create(ExecutionEnvironment env) { - return new Ops(new Scope(env)); + return new Ops(env.baseScope()); } /** @@ -7893,6 +7893,6 @@ public static Ops create(ExecutionEnvironment env) { *

Invoking this method is equivalent to {@code Ops.create(EagerSession.getDefault())}. */ public static Ops create() { - return new Ops(new Scope(EagerSession.getDefault())); + return create(EagerSession.getDefault()); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java index dad842f7038..c5d67128406 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java @@ -1,18 +1,18 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_ContextOptionsSetAsync; @@ -29,6 +29,7 @@ import org.tensorflow.internal.c_api.TFE_ContextOptions; import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.op.Op; +import org.tensorflow.op.Scope; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; @@ -112,7 +113,8 @@ public Options devicePlacementPolicy(DevicePlacementPolicy value) { * Configures the session based on the data found in the provided configuration. * * @param config a config protocol buffer - * @see config.proto + * @see config.proto */ public Options config(ConfigProto config) { this.config = config; @@ -306,6 +308,11 @@ public void checkInput(Op input) { } } + @Override + public Scope baseScope() { + return baseScope; + } + TFE_Context nativeHandle() { checkSession(); return nativeHandle; @@ -314,17 +321,16 @@ TFE_Context nativeHandle() { /** * Attach the list of native resources to this eager session scope. * - *

When the eager session is closed (i.e. by calling {@link #close()} explicitly or - * implicitly via try-with-resources), all native resources attached to the session will be - * released as well, unless so other references are {@link Pointer#retainReference() retaining} - * them.

+ *

When the eager session is closed (i.e. by calling {@link #close()} explicitly or implicitly + * via try-with-resources), all native resources attached to the session will be released as well, + * unless so other references are {@link Pointer#retainReference() retaining} them. * *

Attached resources can still be garbage collected though if their associated {@link Pointer} * is no longer reachable in Java, independently of their reference count. Therefore, it is * assumed that these resources are not required by the native library once the Java client no - * longer needs them.

+ * longer needs them. * - *

Attaching a resource already attached to this session will have no effect.

+ *

Attaching a resource already attached to this session will have no effect. * * @param resources resources to attach to the session */ @@ -339,14 +345,14 @@ void attach(Pointer... resources) { * Detach a list of resources from this eager session scope. * *

Detached native resources will prevent them to be automatically released when the session is - * closed.

+ * closed. * *

Note though that this method will decrement the reference count of each resources being - * detached, which may automatically released them if that count reaches 0. Therefore, - * invoking {@link Pointer#retainReference()} prior to this call on any resource that must remain - * valid after being detached might be required.

+ * detached, which may automatically released them if that count reaches 0. Therefore, invoking + * {@link Pointer#retainReference()} prior to this call on any resource that must remain valid + * after being detached might be required. * - *

Detaching a resource that is not attached to this session will have no effect.

+ *

Detaching a resource that is not attached to this session will have no effect. * * @param resources resources to detach from the session */ @@ -362,6 +368,8 @@ void detach(Pointer... resources) { private final WeakPointerScope nativeResources; private TFE_Context nativeHandle; + private final Scope baseScope = new Scope(this); + private EagerSession(Options options) { this.nativeResources = new WeakPointerScope(); this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config); @@ -381,7 +389,8 @@ private synchronized void doClose() { } } - private static TFE_Context allocate(boolean async, int devicePlacementPolicy, ConfigProto config) { + private static TFE_Context allocate( + boolean async, int devicePlacementPolicy, ConfigProto config) { try (PointerScope scope = new PointerScope()) { TFE_ContextOptions opts = TFE_ContextOptions.newContextOptions(); TF_Status status = TF_Status.newStatus(); @@ -390,7 +399,7 @@ private static TFE_Context allocate(boolean async, int devicePlacementPolicy, Co TFE_ContextOptionsSetConfig(opts, configBytes, configBytes.capacity(), status); status.throwExceptionIfNotOK(); } - TFE_ContextOptionsSetAsync(opts, (byte)(async ? 1 : 0)); + TFE_ContextOptionsSetAsync(opts, (byte) (async ? 1 : 0)); TFE_ContextOptionsSetDevicePlacementPolicy(opts, devicePlacementPolicy); TFE_Context context = TFE_NewContext(opts, status); status.throwExceptionIfNotOK(); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java index d5389bcd0ad..a18c7fff38b 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java @@ -1,25 +1,24 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import org.tensorflow.op.Op; +import org.tensorflow.op.Scope; -/** - * Defines an environment for creating and executing TensorFlow {@link Operation}s. - */ +/** Defines an environment for creating and executing TensorFlow {@link Operation}s. */ public interface ExecutionEnvironment { enum Types { @@ -49,11 +48,12 @@ default boolean isOpEnabled(String opType) { } /** - * Checks that {@code input} is valid to use as an input in this execution environment. Throws {@link - * IllegalArgumentException} if not. + * Checks that {@code input} is valid to use as an input in this execution environment. Throws + * {@link IllegalArgumentException} if not. * * @param input The op to check - * @throws IllegalArgumentException if input can't be used as an input in this execution environment. + * @throws IllegalArgumentException if input can't be used as an input in this execution + * environment. */ void checkInput(Op input); @@ -71,4 +71,10 @@ default boolean isEager() { default boolean isGraph() { return environmentType() == Types.GRAPH; } + + /** + * Get the top level scope for this execution environment. Is cached, which is necessary to + * prevent name collisions. + */ + Scope baseScope(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index 7f659b262a6..b69fe89da0a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -1,18 +1,18 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import static org.tensorflow.internal.c_api.global.tensorflow.TF_AddGradientsWithPrefix; @@ -52,6 +52,7 @@ import org.tensorflow.ndarray.StdArrays; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; +import org.tensorflow.op.Scope; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Identity; import org.tensorflow.op.core.NoOp; @@ -63,7 +64,6 @@ import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; - /** * A data flow graph representing a TensorFlow computation. * @@ -74,18 +74,16 @@ */ public final class Graph implements ExecutionEnvironment, AutoCloseable { - /** - * Create an empty Graph. - */ + /** Create an empty Graph. */ public Graph() { nativeHandle = allocate(); + this.baseScope = new Scope(this); } - /** - * Create a Graph from an existing handle (takes ownership). - */ + /** Create a Graph from an existing handle (takes ownership). */ Graph(TF_Graph nativeHandle) { this.nativeHandle = nativeHandle; + this.baseScope = new Scope(this); } Graph(TF_Graph nativeHandle, SaverDef saverDef) { @@ -138,8 +136,8 @@ public GraphOperation operation(String name) { } /** - * Returns the operation (node in the Graph) with the provided name, or throws {@link IllegalArgumentException} if - * there isn't one. + * Returns the operation (node in the Graph) with the provided name, or throws {@link + * IllegalArgumentException} if there isn't one. * * @param name name of the operation to look for * @return operation in the graph with this name @@ -155,9 +153,9 @@ public GraphOperation operationOrThrow(String name) { /** * Returns the output with the provided name, or {@code null} if there is no such output. - *

Names should be of the - * format {@code /scope/op}, with an optional index: {@code /scope/op:1}. {@code 0} is used if the index is not - * specified. + * + *

Names should be of the format {@code /scope/op}, with an optional index: {@code + * /scope/op:1}. {@code 0} is used if the index is not specified. * * @param output the output to get * @return the output with this name, or null if there isn't one @@ -181,15 +179,17 @@ public Output output(String output) { } return new Output(operation, index); } catch (NumberFormatException e) { - throw new IllegalArgumentException("Could not get output for badly formatted output name: \"" + output + "\"", e); + throw new IllegalArgumentException( + "Could not get output for badly formatted output name: \"" + output + "\"", e); } } /** - * Returns the output with the provided name, or throws {@link IllegalArgumentException} if there isn't one. - *

Names should be of the - * format {@code /scope/op}, with an optional index: {@code /scope/op:1}. {@code 0} is used if the index is not - * specified. + * Returns the output with the provided name, or throws {@link IllegalArgumentException} if there + * isn't one. + * + *

Names should be of the format {@code /scope/op}, with an optional index: {@code + * /scope/op:1}. {@code 0} is used if the index is not specified. * * @param output the output to get * @return the output with this name @@ -220,16 +220,20 @@ private GraphOperation graphOp(Operand operand) { } /** - * Finds the operations used to produce {@code outputs}, assuming {@code inputs} are provided. Includes control dependencies. - *

- * Note that this function can easily return ops upstream of inputs as part of the body. Depending on your use, the - * returned body should probably be filtered for {@code Placeholder}s, at least. + * Finds the operations used to produce {@code outputs}, assuming {@code inputs} are provided. + * Includes control dependencies. + * + *

Note that this function can easily return ops upstream of inputs as part of the body. + * Depending on your use, the returned body should probably be filtered for {@code Placeholder}s, + * at least. * - * @param inputs the inputs of the subgraph. Must be from single output ops. May not be null. - * @param outputs the outputs of the subgraph. May not be null. - * @return the set of operations needed to calculate outputs from inputs, including outputs and inputs + * @param inputs the inputs of the subgraph. Must be from single output ops. May not be null. + * @param outputs the outputs of the subgraph. May not be null. + * @return the set of operations needed to calculate outputs from inputs, including outputs and + * inputs */ - public synchronized Set completeSubgraph(Set> inputs, Set> outputs) { + public synchronized Set completeSubgraph( + Set> inputs, Set> outputs) { if (inputs == null) { throw new IllegalArgumentException("Inputs can't be null."); @@ -245,7 +249,8 @@ public synchronized Set completeSubgraph(Set> inputs, for (Operand input : inputs) { if (input.op().numOutputs() > 1) { - throw new IllegalStateException("Only ops with one output are supported as subgraph inputs"); + throw new IllegalStateException( + "Only ops with one output are supported as subgraph inputs"); } GraphOperation op = graphOp(input); inputOps.add(op); @@ -277,15 +282,14 @@ public synchronized Set completeSubgraph(Set> inputs, currents.add(inputOp); } } - } return seen; } /** - * Get all ops directly or indirectly required to calculate {@code outputs} (not including {@code outputs}), including - * control dependencies. + * Get all ops directly or indirectly required to calculate {@code outputs} (not including {@code + * outputs}), including control dependencies. * * @param outputs the starting points of the traversal. * @return the ops needed to calculate {@code outputs}, not including {@code outputs} @@ -306,8 +310,8 @@ public Set subgraphToOps(Set outputs) { } /** - * Get all ops that use one of {@code inputs} directly or indirectly (not including {@code inputs}), including control - * dependencies. + * Get all ops that use one of {@code inputs} directly or indirectly (not including {@code + * inputs}), including control dependencies. * * @param inputs the starting points of the traversal. * @return the ops that depend on {@code inputs}, not including {@code inputs} @@ -328,8 +332,8 @@ public synchronized Set subgraphFromOps(Set inpu } /** - * Get all ops directly or indirectly required to calculate {@code outputs} (not including {@code outputs}), including - * control dependencies. + * Get all ops directly or indirectly required to calculate {@code outputs} (not including {@code + * outputs}), including control dependencies. * * @param outputs the starting points of the traversal. * @return the ops needed to calculate {@code outputs}, not including {@code outputs} @@ -339,8 +343,8 @@ public Set subgraphTo(Set> outputs) { } /** - * Get all ops that use one of {@code inputs} directly or indirectly (not including {@code inputs}), including control - * dependencies. + * Get all ops that use one of {@code inputs} directly or indirectly (not including {@code + * inputs}), including control dependencies. * * @param inputs the starting points of the traversal. * @return the ops that depend on {@code inputs}, not including {@code inputs} @@ -363,8 +367,8 @@ public synchronized Set subgraphFrom(Set> inputs) { * @param type of the Operation (i.e., identifies the computation to be performed) * @param name to refer to the created Operation in the graph. * @return an {@link OperationBuilder}, which will add the Operation to the graph when {@link - * OperationBuilder#build()} is invoked. If {@link OperationBuilder#build()} is not invoked, then some resources may - * leak. + * OperationBuilder#build()} is invoked. If {@link OperationBuilder#build()} is not invoked, + * then some resources may leak. */ @Override public GraphOperationBuilder opBuilder(String type, String name) { @@ -383,18 +387,26 @@ public Types environmentType() { public void checkInput(Op input) { if (input.env().isEager()) { throw new IllegalArgumentException( - "Input " + input + " was from an eager session, can't use in a graph. Use tf.constantOf(input.asTensor())"); + "Input " + + input + + " was from an eager session, can't use in a graph. Use tf.constantOf(input.asTensor())"); } if (input.env() != this) { - throw new IllegalArgumentException("Input " + input + " was from a different graph, can't use."); + throw new IllegalArgumentException( + "Input " + input + " was from a different graph, can't use."); } } + @Override + public Scope baseScope() { + return baseScope; + } + /** * Import a representation of a TensorFlow graph. * - *

The representation of the graph, referred to as a {@code GraphDef}, can be - * generated by {@link #toGraphDef()} and equivalents in other language APIs. + *

The representation of the graph, referred to as a {@code GraphDef}, can be generated by + * {@link #toGraphDef()} and equivalents in other language APIs. * * @param graphDef {@code GraphDef} proto to import * @throws IllegalArgumentException if graphDef is not a recognized serialization of a graph. @@ -442,19 +454,18 @@ public synchronized void addInitializer(Op initializer) { initializers.add(initializer); } - /** - * Returns all initializers added to the graph via {@link #addInitializer(Op)} - */ + /** Returns all initializers added to the graph via {@link #addInitializer(Op)} */ public List initializers() { return Collections.unmodifiableList(initializers); } /** - * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e., {@code d(y_1 + y_2 - * + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...} + * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e., + * {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...} * *

{@code dx} are used as initial gradients (which represent the symbolic partial derivatives - * of some loss function {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of {@code y}. + * of some loss function {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of + * {@code y}. * *

If {@code dx} is null, the implementation will use dx of {@link * org.tensorflow.op.core.OnesLike OnesLike} for all shapes in {@code y}. @@ -464,8 +475,8 @@ public List initializers() { * *

If {@code prefix} is null, then one will be chosen automatically. * - * @param prefix unique string prefix applied before the names of nodes added to the graph to compute gradients. If - * null, a default one will be chosen. + * @param prefix unique string prefix applied before the names of nodes added to the graph to + * compute gradients. If null, a default one will be chosen. * @param y output of the function to derive * @param x inputs of the function for which partial derivatives are computed * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y} @@ -515,8 +526,11 @@ public Output[] addGradients(String prefix, Output[] y, Output[] x, Out dxIndices); int ndy = dyHandlesAndIndices.length >> 1; if (ndy != dy.length) { - throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length - + " were expected"); + throw new IllegalStateException( + String.valueOf(ndy) + + " gradients were added to the graph when " + + dy.length + + " were expected"); } for (int i = 0, j = ndy; i < ndy; ++i, ++j) { GraphOperation op = new GraphOperation(this, (TF_Operation) dyHandlesAndIndices[i]); @@ -527,23 +541,24 @@ public Output[] addGradients(String prefix, Output[] y, Output[] x, Out } /** - * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e., {@code dy/dx_1, - * dy/dx_2...} - *

- * This is a simplified version of {@link #addGradients(String, Output[], Output[], Output[])} where {@code y} is a - * single output, {@code dx} is null and {@code prefix} is null. + * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e., + * {@code dy/dx_1, dy/dx_2...} + * + *

This is a simplified version of {@link #addGradients(String, Output[], Output[], Output[])} + * where {@code y} is a single output, {@code dx} is null and {@code prefix} is null. * * @param y output of the function to derive * @param x inputs of the function for which partial derivatives are computed * @return the partial derivatives {@code dy} with the size of {@code x} */ public Output[] addGradients(Output y, Output[] x) { - return addGradients(null, new Output[]{y}, x, null); + return addGradients(null, new Output[] {y}, x, null); } /** - * Used to instantiate an abstract class which overrides the buildSubgraph method to build a conditional or body - * subgraph for a while loop. After Java 8, this can alternatively be used to create a lambda for the same purpose. + * Used to instantiate an abstract class which overrides the buildSubgraph method to build a + * conditional or body subgraph for a while loop. After Java 8, this can alternatively be used to + * create a lambda for the same purpose. * *

To be used when calling {@link #whileLoop(Output[], * org.tensorflow.Graph.WhileSubgraphBuilder, org.tensorflow.Graph.WhileSubgraphBuilder, String)} @@ -558,7 +573,9 @@ public Output[] addGradients(Output y, Output[] x) { * } * }; * + * * Example usage (after Java 8): + * *

    * WhileSubgraphBuilder bodyGraphBuilder = (bodyGraph, bodyInputs, bodyOutputs) -> { //
    *   build body subgraph
@@ -657,13 +674,15 @@ public Output[] whileLoop(
   }
 
   /**
-   * Return the {@link SaverDef} instance used to save the state of all variables present in this graph.
+   * Return the {@link SaverDef} instance used to save the state of all variables present in this
+   * graph.
    *
-   * 

The first time this method is called it builds the {@link SaverDef}. If this graph already contains a - * "save/restore_all" operation then it is assumed to contain all necessary saving and restoring operations. If that - * operation does not exist then the graph is mutated to add all the nodes necessary to save and restore the state of - * the graph. Consequently, any variables that are added to the graph after this call will not be saved nor restored - * using this {@link SaverDef}. + *

The first time this method is called it builds the {@link SaverDef}. If this graph already + * contains a "save/restore_all" operation then it is assumed to contain all necessary saving and + * restoring operations. If that operation does not exist then the graph is mutated to add all the + * nodes necessary to save and restore the state of the graph. Consequently, any variables that + * are added to the graph after this call will not be saved nor restored using this {@link + * SaverDef}. * * @return a {@link SaverDef} instance */ @@ -678,11 +697,12 @@ synchronized SaverDef saverDef() { // regenerate SaverDef without mutating. The names mirror // the python implementation for compatibility. // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py - saverDef = SaverDef.newBuilder() - .setFilenameTensorName("save/filename") - .setSaveTensorName("save/control_dependency") - .setRestoreOpName("save/restore_all") - .build(); + saverDef = + SaverDef.newBuilder() + .setFilenameTensorName("save/filename") + .setSaveTensorName("save/control_dependency") + .setRestoreOpName("save/restore_all") + .build(); } } return saverDef; @@ -692,6 +712,7 @@ synchronized SaverDef saverDef() { private TF_Graph nativeHandle; private int refcount = 0; private SaverDef saverDef; + private final Scope baseScope; private final List initializers = new ArrayList<>(); @@ -757,7 +778,9 @@ private final void advance() { try { Object[] nativeReturn = nextOperation(reference.nativeHandle(), this.position); - if (nativeReturn != null && nativeReturn[0] != null && !((TF_Operation) nativeReturn[0]).isNull()) { + if (nativeReturn != null + && nativeReturn[0] != null + && !((TF_Operation) nativeReturn[0]).isNull()) { this.operation = new GraphOperation(this.graph, (TF_Operation) nativeReturn[0]); this.position = (Integer) nativeReturn[1]; } @@ -863,14 +886,21 @@ private static GraphDef toGraphDef(TF_Graph handle) { } } - static void resolveOutputs(String type, TF_Operation[] srcOps, - int[] srcIndices, TF_Output dst, int n) { + static void resolveOutputs( + String type, TF_Operation[] srcOps, int[] srcIndices, TF_Output dst, int n) { if (srcOps.length != n) { - throw new IllegalArgumentException("expected " + n + ", got " + srcOps.length + " " + type + " Operations"); + throw new IllegalArgumentException( + "expected " + n + ", got " + srcOps.length + " " + type + " Operations"); } if (srcIndices.length != n) { throw new IllegalArgumentException( - "expected " + n + ", got " + srcIndices.length + " " + type + " Operation output indices"); + "expected " + + n + + ", got " + + srcIndices.length + + " " + + type + + " Operation output indices"); } for (int i = 0; i < n; ++i) { if (srcOps[i] == null || srcOps[i].isNull()) { @@ -905,7 +935,8 @@ private static Object[] addGradients( resolveOutputs("x", outputHandles, outputIndices, x, nx); if (gradInputHandles != null) { if (gradInputHandles.length != ny) { - throw new IllegalArgumentException("expected " + ny + ", got " + gradInputHandles.length + " handles"); + throw new IllegalArgumentException( + "expected " + ny + ", got " + gradInputHandles.length + " handles"); } dx = new TF_Output(ny); resolveOutputs("dx", gradInputHandles, gradInputIndices, dx, ny); @@ -961,9 +992,13 @@ private static Object[] whileLoop( condOutputIndices[0] = condOutputOutput.index(); Object[] condOutputHandlesAndIndices = - buildSubgraph(condGraphBuilder, params.cond_graph(), - condInputHandles, condInputIndices, - condOutputHandles, condOutputIndices); + buildSubgraph( + condGraphBuilder, + params.cond_graph(), + condInputHandles, + condInputIndices, + condOutputHandles, + condOutputIndices); // build body subgraph TF_Output bodyInputsOutput = params.body_inputs(); @@ -980,22 +1015,28 @@ private static Object[] whileLoop( } Object[] bodyOutputHandlesAndIndices = - buildSubgraph(bodyGraphBuilder, params.body_graph(), - bodyInputHandles, bodyInputIndices, - bodyOutputHandles, bodyOutputIndices); - - if (condOutputHandlesAndIndices == null || - bodyOutputHandlesAndIndices == null) { + buildSubgraph( + bodyGraphBuilder, + params.body_graph(), + bodyInputHandles, + bodyInputIndices, + bodyOutputHandles, + bodyOutputIndices); + + if (condOutputHandlesAndIndices == null || bodyOutputHandlesAndIndices == null) { return null; } // set cond_output param to output of the conditional subgraph - condOutputOutput.oper((TF_Operation) condOutputHandlesAndIndices[0]) + condOutputOutput + .oper((TF_Operation) condOutputHandlesAndIndices[0]) .index((Integer) condOutputHandlesAndIndices[1]); // set body_outputs param to outputs of the body subgraph for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) { - bodyOutputsOutput.position(i).oper((TF_Operation) bodyOutputHandlesAndIndices[i]) + bodyOutputsOutput + .position(i) + .oper((TF_Operation) bodyOutputHandlesAndIndices[i]) .index((Integer) bodyOutputHandlesAndIndices[j]); } @@ -1042,20 +1083,12 @@ private static SaverDef addVariableSaver(Graph graph) { Operand varSlices = tf.zerosLike(varNamesTensor); Placeholder saveFilename = tf.withName("filename").placeholder(TString.class); - Save saveVariables = tf.train.save( - saveFilename, - varNamesTensor, - varSlices, - varOutputs - ); - Identity id = tf.withControlDependencies(Arrays.asList(saveFilename, saveVariables)) - .withName("control_dependency").identity(saveFilename); - Restore restoreVariables = tf.train.restore( - saveFilename, - varNamesTensor, - varSlices, - varTypes - ); + Save saveVariables = tf.train.save(saveFilename, varNamesTensor, varSlices, varOutputs); + Identity id = + tf.withControlDependencies(Arrays.asList(saveFilename, saveVariables)) + .withName("control_dependency") + .identity(saveFilename); + Restore restoreVariables = tf.train.restore(saveFilename, varNamesTensor, varSlices, varTypes); List restoreOps = new ArrayList<>(varOutputs.size()); for (int i = 0; i < varOutputs.size(); ++i) { restoreOps.add(tf.assign(varOutputs.get(i), (Operand) restoreVariables.tensors().get(i))); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java index 2e84cac1ac7..903a12f66b2 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java @@ -1,23 +1,26 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow.op; import java.util.HashMap; import java.util.Map; +import java.util.regex.Matcher; import java.util.regex.Pattern; +import org.tensorflow.ExecutionEnvironment; +import org.tensorflow.Graph; /** * A class to manage scoped (hierarchical) names for operators. @@ -36,12 +39,12 @@ */ final class NameScope { - NameScope withSubScope(String scopeName) { + NameScope withSubScope(String scopeName, ExecutionEnvironment env) { checkPattern(NAME_REGEX, scopeName); // Override with opName if it exists. String actualName = (opName != null) ? opName : scopeName; String newPrefix = fullyQualify(makeUnique(actualName)); - return new NameScope(newPrefix, null, null); + return new NameScope(newPrefix, null, null).withUsedFrom(env); } NameScope withName(String name) { @@ -50,6 +53,46 @@ NameScope withName(String name) { return new NameScope(opPrefix, name, ids); } + private static final Pattern NAME_PATTERN = Pattern.compile("(.+)_(\\d+)", Pattern.DOTALL); + + /** "Import" used names from a graph. Useful when adding to a loaded graph. */ + private NameScope withUsedFrom(ExecutionEnvironment env) { + + if (env instanceof Graph) { + ((Graph) env) + .operations() + .forEachRemaining( + op -> { + if (op.name().startsWith(opPrefix != null ? opPrefix : "")) { + String name = op.name(); + + if (opPrefix != null) { + name = name.substring(opPrefix.length() + 1); + } + + if (!name.contains("/")) { + Matcher matcher = NAME_PATTERN.matcher(name); + if (matcher.find()) { + String realName = matcher.group(1); + int num = Integer.parseInt(matcher.group(2)) + 1; + + if (!(ids.containsKey(realName) && ids.get(realName) > num)) { + ids.put(realName, num); + } + } else { + if (!ids.containsKey(name)) { + ids.put(name, 1); + } else { + ids.put(name, ids.get(name) + 1); + } + } + } + } + }); + } + return this; + } + String makeOpName(String name) { checkPattern(NAME_REGEX, name); // Override with opName if it exists. @@ -62,9 +105,12 @@ String makeOpName(String name) { * *

A root-level namescope generates operator names with no components, like {@code Const_72} * and {@code result}. + * + * @param env */ - NameScope() { + NameScope(ExecutionEnvironment env) { this(null, null, null); + withUsedFrom(env); } private NameScope(String opPrefix, String opName, Map ids) { @@ -120,6 +166,13 @@ private String fullyQualify(String name) { // instance mapped to the next available numeric suffix for it. private final Map ids; + static boolean isValidName(String name) { + if (name == null) { + return false; + } + return NAME_REGEX.matcher(name).matches(); + } + private static void checkPattern(Pattern pattern, String name) { if (name == null) { throw new IllegalArgumentException("Names cannot be null"); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java index 85e283d9260..2aef70f6af0 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java @@ -1,18 +1,18 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow.op; import java.util.ArrayList; @@ -21,7 +21,8 @@ import org.tensorflow.OperationBuilder; /** - * Manages groups of related properties when creating Tensorflow Operations, such as a common name prefix. + * Manages groups of related properties when creating Tensorflow Operations, such as a common name + * prefix. * *

A {@code Scope} is a container for common properties applied to TensorFlow Ops. Normal user * code initializes a {@code Scope} and provides it to Operation building classes. For example: @@ -80,15 +81,16 @@ public final class Scope { /** * Create a new top-level scope. * + *

For internal use only, use {@link ExecutionEnvironment#baseScope()} if you need a + * base level scope. + * * @param env The execution environment used by the scope. */ public Scope(ExecutionEnvironment env) { - this(env, new NameScope(), new ArrayList<>(), DeviceSpec.newBuilder().build()); + this(env, new NameScope(env), new ArrayList<>(), DeviceSpec.newBuilder().build()); } - /** - * Returns the execution environment used by this scope. - */ + /** Returns the execution environment used by this scope. */ public ExecutionEnvironment env() { return env; } @@ -97,7 +99,8 @@ public ExecutionEnvironment env() { * Returns a new scope where added operations will have the provided name prefix. * *

Ops created with this scope will have {@code name/childScopeName/} as the prefix. The actual - * name will be unique in the returned scope. All other properties are inherited from the current scope. + * name will be unique in the returned scope. All other properties are inherited from the current + * scope. * *

The child scope name must match the regular expression {@code [A-Za-z0-9.][A-Za-z0-9_.\-]*} * @@ -106,7 +109,8 @@ public ExecutionEnvironment env() { * @throws IllegalArgumentException if the name is invalid */ public Scope withSubScope(String childScopeName) { - return new Scope(env, nameScope.withSubScope(childScopeName), controlDependencies, deviceSpec); + return new Scope( + env, nameScope.withSubScope(childScopeName, env), controlDependencies, deviceSpec); } /** @@ -126,29 +130,34 @@ public Scope withName(String opName) { } /** - * Returns a new scope where added operations will be prefixed by this scope's op name - * (set by {@link #withName(String)}), or the given default if it is unset. This is intended to be used for - * composite ops. + * Returns a new scope where added operations will be prefixed by this scope's op name (set by + * {@link #withName(String)}), or the given default if it is unset. This is intended to be used + * for composite ops. * - *

Ops created with this scope will have {@code name/opName/} as the prefix. The actual - * name will be unique in the returned scope. All other properties are inherited from the current + *

Ops created with this scope will have {@code name/opName/} as the prefix. The actual name + * will be unique in the returned scope. All other properties are inherited from the current * scope. * - *

The default child scope name must match the regular expression {@code [A-Za-z0-9.][A-Za-z0-9_.\-]*} + *

The default child scope name must match the regular expression {@code + * [A-Za-z0-9.][A-Za-z0-9_.\-]*} * * @param defaultName name of the sub scope if this scope's name hasn't been set. * @return a new subscope * @throws IllegalArgumentException if the name is invalid */ - public Scope withNameAsSubScope(String defaultName){ - return new Scope(env, nameScope.withSubScope(nameScope.makeOpName(defaultName)), controlDependencies, deviceSpec); + public Scope withNameAsSubScope(String defaultName) { + return new Scope( + env, + nameScope.withSubScope(nameScope.makeOpName(defaultName), env), + controlDependencies, + deviceSpec); } /** * Return a new scope that uses the provided device specification for an op. * - *

Operations created within this scope will place the created operations on the device(s) matching the provided - * spec. + *

Operations created within this scope will place the created operations on the device(s) + * matching the provided spec. * * @param deviceSpec device specification for an operator in the returned scope * @return a new Scope that uses opName for operations. @@ -170,8 +179,8 @@ public Scope withDevice(DeviceSpec deviceSpec) { * }

* *

Note: if you provide a composite operator building class (i.e, a class that creates a - * set of related operations by calling other operator building code), the provided name will act as a subscope to all - * underlying operators. + * set of related operations by calling other operator building code), the provided name will act + * as a subscope to all underlying operators. * * @param defaultName name for the underlying operator. * @return unique name for the operator. @@ -181,8 +190,15 @@ public String makeOpName(String defaultName) { return nameScope.makeOpName(defaultName); } + public static boolean isValidOpName(String name) { + return NameScope.isValidName(name); + } + private Scope( - ExecutionEnvironment env, NameScope nameScope, Iterable controlDependencies, DeviceSpec deviceSpec) { + ExecutionEnvironment env, + NameScope nameScope, + Iterable controlDependencies, + DeviceSpec deviceSpec) { this.env = env; this.nameScope = nameScope; this.controlDependencies = controlDependencies; @@ -206,8 +222,8 @@ public Scope withControlDependencies(Iterable controls) { } /** - * Applies device specification and adds each Operand in controlDependencies as a control input to the provided - * builder. + * Applies device specification and adds each Operand in controlDependencies as a control input to + * the provided builder. * * @param builder OperationBuilder to add control inputs and device specification to */ @@ -233,9 +249,7 @@ public OperationBuilder applyControlDependencies(OperationBuilder builder) { private final NameScope nameScope; private final DeviceSpec deviceSpec; - /** - * Returns device string from the scope. - */ + /** Returns device string from the scope. */ public String getDeviceString() { return deviceSpec.toString(); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java index 62881dcee8c..84eabd3da1a 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java @@ -1,18 +1,18 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow.op; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -29,6 +29,24 @@ /** Unit tests for {@link org.tensorflow.op.Scope}. */ public class ScopeTest { + @Test + public void testSeparateOps() { + try (Graph g = new Graph()) { + Ops tf1 = Ops.create(g); + Ops tf2 = Ops.create(g); + + tf1.constant(2); + tf1.withName("Constant2").constant(2); + tf1.withSubScope("Scope").constant(2); + tf1.withSubScope("Scope").withName("Constant4").constant(2); + + tf2.constant(2); + tf2.withName("Constant2").constant(2); + tf2.withSubScope("Scope").constant(2); + tf2.withSubScope("Scope").withName("Constant4").constant(2); + } + } + @Test public void basicNames() { try (Graph g = new Graph()) { @@ -168,9 +186,9 @@ public void composite() { // assertNotNull(g.operation("variance/zero")); // Verify correct results as well. - TInt32 result = (TInt32)sess.runner().fetch(var1.output()).run().get(0); + TInt32 result = (TInt32) sess.runner().fetch(var1.output()).run().get(0); assertEquals(21704, result.getInt()); - result = (TInt32)sess.runner().fetch(var2.output()).run().get(0); + result = (TInt32) sess.runner().fetch(var2.output()).run().get(0); assertEquals(21704, result.getInt()); } } diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java index bea817e9011..1b1d5cb0fb3 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java @@ -1,19 +1,18 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow.processor.operator; import com.github.javaparser.ast.comments.JavadocComment; @@ -156,26 +155,28 @@ public Set getSupportedAnnotationTypes() { } private static class OpsSpec { - private static final Comparator PARAMETER_SPEC_COMPARATOR = (o1, o2) -> { - if (o1.parameters.size() > o2.parameters.size()) { - return 1; - } - if (o1.parameters.size() < o2.parameters.size()) { - return -1; - } - List firstParams = o1.parameters; - List secondParams = o2.parameters; - for (int i = 0; i < firstParams.size(); i++) { - ParameterSpec first = firstParams.get(i); - ParameterSpec second = secondParams.get(i); - int compare = first.name.compareTo(second.name); - if (compare != 0) { - return compare; - } - } - return 0; - }; - private static final Comparator METHOD_SPEC_COMPARATOR = Comparator.comparing((MethodSpec m) -> m.name).thenComparing(PARAMETER_SPEC_COMPARATOR); + private static final Comparator PARAMETER_SPEC_COMPARATOR = + (o1, o2) -> { + if (o1.parameters.size() > o2.parameters.size()) { + return 1; + } + if (o1.parameters.size() < o2.parameters.size()) { + return -1; + } + List firstParams = o1.parameters; + List secondParams = o2.parameters; + for (int i = 0; i < firstParams.size(); i++) { + ParameterSpec first = firstParams.get(i); + ParameterSpec second = secondParams.get(i); + int compare = first.name.compareTo(second.name); + if (compare != 0) { + return compare; + } + } + return 0; + }; + private static final Comparator METHOD_SPEC_COMPARATOR = + Comparator.comparing((MethodSpec m) -> m.name).thenComparing(PARAMETER_SPEC_COMPARATOR); final String groupName; final String fieldName; @@ -183,7 +184,8 @@ private static class OpsSpec { final List methods; final List subGroups = new ArrayList<>(); - OpsSpec(String groupName, String fieldName, ClassName className, Collection methods) { + OpsSpec( + String groupName, String fieldName, ClassName className, Collection methods) { this.groupName = groupName; this.fieldName = fieldName; this.className = className; @@ -227,11 +229,11 @@ private void error(Element e, String message, Object... args) { private void write(TypeSpec spec) { try { JavaFile.builder("org.tensorflow.op", spec) - .addFileComment(LICENSE) - .addFileComment("\nThis class has been generated, DO NOT EDIT!\n") - .skipJavaLangImports(true) - .build() - .writeTo(filer); + .addFileComment(LICENSE) + .addFileComment("\nThis class has been generated, DO NOT EDIT!\n") + .skipJavaLangImports(true) + .build() + .writeTo(filer); } catch (IOException e) { throw new AssertionError(e); } @@ -262,7 +264,7 @@ private boolean collectOpsMethods( result = false; continue; } - collectOpMethods(groupedMethods, (TypeElement)e, annotation); + collectOpMethods(groupedMethods, (TypeElement) e, annotation); } return result; } @@ -281,7 +283,8 @@ private void collectOpMethods( String opGroup = getAnnotationElementValueAsString("group", operatorAnnot); String opName = getAnnotationElementValueAsString("name", operatorAnnot); if (Strings.isNullOrEmpty(opName)) { - opName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, ClassName.get(opClass).simpleName()); + opName = + CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, ClassName.get(opClass).simpleName()); } // Build an endpoint for each method annotated with @Endpoint, which takes in parameter a scope // and, optionally, a list of arguments @@ -293,11 +296,17 @@ private void collectOpMethods( throw new IllegalArgumentException( "Endpoint " + opMethod + " of class " + opClass + " must be static and public"); } - if (opMethod.getParameters().isEmpty() || - !((TypeElement)types.asElement(opMethod.getParameters().get(0).asType())).getQualifiedName() + if (opMethod.getParameters().isEmpty() + || !((TypeElement) types.asElement(opMethod.getParameters().get(0).asType())) + .getQualifiedName() .equals(elements.getName(Names.Scope.toString()))) { throw new IllegalArgumentException( - "Endpoint " + opMethod + " of class " + opClass + " must take an instance of " + Names.Scope + "Endpoint " + + opMethod + + " of class " + + opClass + + " must take an instance of " + + Names.Scope + " as its first parameter"); } String endpointGroup = getAnnotationElementValueAsString("group", endpointAnnot); @@ -311,15 +320,19 @@ private void collectOpMethods( boolean describeByClass = getAnnotationElementValueAsBoolean("describeByClass", endpointAnnot, false); boolean deprecated = opMethod.getAnnotation(Deprecated.class) != null || opClassDeprecated; - MethodSpec method = buildOpMethod(endpointName, opClass, opMethod, describeByClass, deprecated); + MethodSpec method = + buildOpMethod(endpointName, opClass, opMethod, describeByClass, deprecated); groupedMethods.put(endpointGroup, method); } } } private MethodSpec buildOpMethod( - String methodName, TypeElement opClass, ExecutableElement endpointMethod, - boolean describeByClass, boolean deprecated) { + String methodName, + TypeElement opClass, + ExecutableElement endpointMethod, + boolean describeByClass, + boolean deprecated) { MethodSpec.Builder builder = MethodSpec.methodBuilder(methodName) .addModifiers(Modifier.PUBLIC) @@ -341,9 +354,7 @@ private MethodSpec buildOpMethod( if (!NoType.class.isAssignableFrom(endpointMethod.getReturnType().getClass())) { call.append("return "); } - call.append("$T.") - .append(endpointMethod.getSimpleName()) - .append("(scope"); + call.append("$T.").append(endpointMethod.getSimpleName()).append("(scope"); boolean first = true; for (VariableElement param : endpointMethod.getParameters()) { ParameterSpec p = ParameterSpec.get(param); @@ -374,50 +385,68 @@ private String buildOpMethodJavadoc( // Copy all endpoint method tags to the description, except for the `scope` parameter which // will be inferred by the Ops class - methodJavadoc.getBlockTags().forEach(t -> { - if (!(t.getTagName().equals("param") && t.getName().map(s -> s.equals("scope")).orElse(false))) { - javadoc.addBlockTag(t); - } - }); + methodJavadoc + .getBlockTags() + .forEach( + t -> { + if (!(t.getTagName().equals("param") + && t.getName().map(s -> s.equals("scope")).orElse(false))) { + javadoc.addBlockTag(t); + } + }); return javadoc.toText(); } - private static Collection collectGroupOps(OpsSpec ops, Multimap groupedMethods) { + private static Collection collectGroupOps( + OpsSpec ops, Multimap groupedMethods) { Map groups = new HashMap<>(); - // The `group` label added in the `@Operator` annotation has the same syntax as a package name, which (in most - // case) consists of a simple label but could also be a deeper tree, like `linalg.sparse`. In this case, - // the `LinalgSparseOps` group should be added as the `sparse` field of the `LinalgOps` group, and the latter + // The `group` label added in the `@Operator` annotation has the same syntax as a package name, + // which (in most + // case) consists of a simple label but could also be a deeper tree, like `linalg.sparse`. In + // this case, + // the `LinalgSparseOps` group should be added as the `sparse` field of the `LinalgOps` group, + // and the latter // should be added as the `linalg` field of the `Ops` root class. - groupedMethods.keys().forEach(group -> { - OpsSpec parentClass = ops; - int startPos = 0; - do { - int delimiterPos = group.indexOf('.', startPos); - String groupName = delimiterPos < 0 ? group : group.substring(0, delimiterPos); - OpsSpec groupOps = groups.get(groupName); - - // Create spec for this group if we have not encountered it yet in our iteration - if (groupOps == null) { - String fieldName = delimiterPos < 0 ? - group.substring(startPos) : group.substring(startPos, delimiterPos); - ClassName className = ClassName.get("org.tensorflow.op", - CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.UPPER_CAMEL, groupName.replace('.', '_')) + "Ops"); - groupOps = new OpsSpec(groupName, fieldName, className, groupedMethods.get(groupName)); - parentClass.subGroups.add(groupOps); - groups.put(groupName, groupOps); - } - parentClass = groupOps; - startPos = delimiterPos + 1; - } while (startPos > 0); - }); + groupedMethods + .keys() + .forEach( + group -> { + OpsSpec parentClass = ops; + int startPos = 0; + do { + int delimiterPos = group.indexOf('.', startPos); + String groupName = delimiterPos < 0 ? group : group.substring(0, delimiterPos); + OpsSpec groupOps = groups.get(groupName); + + // Create spec for this group if we have not encountered it yet in our iteration + if (groupOps == null) { + String fieldName = + delimiterPos < 0 + ? group.substring(startPos) + : group.substring(startPos, delimiterPos); + ClassName className = + ClassName.get( + "org.tensorflow.op", + CaseFormat.LOWER_UNDERSCORE.to( + CaseFormat.UPPER_CAMEL, groupName.replace('.', '_')) + + "Ops"); + groupOps = + new OpsSpec(groupName, fieldName, className, groupedMethods.get(groupName)); + parentClass.subGroups.add(groupOps); + groups.put(groupName, groupOps); + } + parentClass = groupOps; + startPos = delimiterPos + 1; + } while (startPos > 0); + }); return groups.values(); } private static TypeSpec buildGroupClass(OpsSpec spec) { - //System.out.println("Generating " + spec.className + " class"); + // System.out.println("Generating " + spec.className + " class"); MethodSpec.Builder ctorBuilder = MethodSpec.constructorBuilder() @@ -436,7 +465,8 @@ private static TypeSpec buildGroupClass(OpsSpec spec) { Names.Ops) .addMethods(spec.methods); - MethodSpec.Builder opsBuilder = MethodSpec.methodBuilder("ops") + MethodSpec.Builder opsBuilder = + MethodSpec.methodBuilder("ops") .addModifiers(Modifier.PUBLIC, Modifier.FINAL) .returns(Names.Ops) .addJavadoc("Get the parent {@link " + Names.Ops.simpleName() + "} object.") @@ -449,21 +479,23 @@ private static TypeSpec buildGroupClass(OpsSpec spec) { builder.addMethod(ctorBuilder.build()); builder.addField( - FieldSpec.builder(Names.Scope, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); + FieldSpec.builder(Names.Scope, "scope") + .addModifiers(Modifier.PRIVATE, Modifier.FINAL) + .build()); builder.addField( - FieldSpec.builder(Names.Ops, "ops").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); + FieldSpec.builder(Names.Ops, "ops").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); return builder.build(); } private static TypeSpec buildTopClass(OpsSpec spec) { - //System.out.println("Generating " + spec.className + " class"); + // System.out.println("Generating " + spec.className + " class"); MethodSpec.Builder ctorBuilder = MethodSpec.constructorBuilder() - .addModifiers(Modifier.PRIVATE) .addParameter(Names.Scope, "scope") + .addModifiers(Modifier.PRIVATE) .addStatement("this.scope = scope", Names.Scope); TypeSpec.Builder opsBuilder = @@ -531,16 +563,16 @@ private static TypeSpec buildTopClass(OpsSpec spec) { .build()); opsBuilder.addMethod( - MethodSpec.methodBuilder("withDevice") - .addModifiers(Modifier.PUBLIC) - .addParameter(Names.DeviceSpec, "deviceSpec") - .returns(Names.Ops) - .addStatement("return new Ops(scope.withDevice(deviceSpec))") - .addJavadoc( - "Returns an API that places the created operations on the device(s) matching the provided spec.\n\n" - + "@see {@link $T#withDevice(DeviceSpec)}\n", - Names.Scope) - .build()); + MethodSpec.methodBuilder("withDevice") + .addModifiers(Modifier.PUBLIC) + .addParameter(Names.DeviceSpec, "deviceSpec") + .returns(Names.Ops) + .addStatement("return new Ops(scope.withDevice(deviceSpec))") + .addJavadoc( + "Returns an API that places the created operations on the device(s) matching the provided spec.\n\n" + + "@see {@link $T#withDevice(DeviceSpec)}\n", + Names.Scope) + .build()); opsBuilder.addMethod( MethodSpec.methodBuilder("withControlDependencies") @@ -555,7 +587,9 @@ private static TypeSpec buildTopClass(OpsSpec spec) { .build()); opsBuilder.addField( - FieldSpec.builder(Names.Scope, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); + FieldSpec.builder(Names.Scope, "scope") + .addModifiers(Modifier.PRIVATE, Modifier.FINAL) + .build()); opsBuilder.addMethod( MethodSpec.methodBuilder("scope") @@ -570,7 +604,7 @@ private static TypeSpec buildTopClass(OpsSpec spec) { .addModifiers(Modifier.PUBLIC, Modifier.STATIC) .addParameter(Names.ExecutionEnvironment, "env") .returns(Names.Ops) - .addStatement("return new Ops(new $T(env))", Names.Scope) + .addStatement("return new Ops(env.baseScope())", Names.Scope) .addJavadoc( "Creates an API for building operations in the provided execution environment\n") .build()); @@ -579,7 +613,7 @@ private static TypeSpec buildTopClass(OpsSpec spec) { MethodSpec.methodBuilder("create") .addModifiers(Modifier.PUBLIC, Modifier.STATIC) .returns(Names.Ops) - .addStatement("return new Ops(new $T($T.getDefault()))", Names.Scope, Names.EagerSession) + .addStatement("return create($T.getDefault())", Names.EagerSession) .addJavadoc( "Creates an API for building operations in the default eager execution environment\n\n" + "

Invoking this method is equivalent to {@code Ops.create(EagerSession.getDefault())}.\n") @@ -588,27 +622,39 @@ private static TypeSpec buildTopClass(OpsSpec spec) { return opsBuilder.build(); } - private static void addGroupFields(TypeSpec.Builder classBuilder, MethodSpec.Builder ctorBuilder, List groups, boolean isTopClass) { - groups.forEach(group -> { - classBuilder.addField( - FieldSpec.builder(group.className, group.fieldName) - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .build() - ); - ctorBuilder.addStatement("$L = new $T(" + (isTopClass ? "this" : "ops") + ")", group.fieldName, group.className).build(); - }); + private static void addGroupFields( + TypeSpec.Builder classBuilder, + MethodSpec.Builder ctorBuilder, + List groups, + boolean isTopClass) { + groups.forEach( + group -> { + classBuilder.addField( + FieldSpec.builder(group.className, group.fieldName) + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .build()); + ctorBuilder + .addStatement( + "$L = new $T(" + (isTopClass ? "this" : "ops") + ")", + group.fieldName, + group.className) + .build(); + }); } private static AnnotationMirror getAnnotationMirror(Element element, Name annotationName) { for (AnnotationMirror am : element.getAnnotationMirrors()) { - if (((TypeElement)am.getAnnotationType().asElement()).getQualifiedName().equals(annotationName)) { + if (((TypeElement) am.getAnnotationType().asElement()) + .getQualifiedName() + .equals(annotationName)) { return am; } } return null; } - private static AnnotationValue getAnnotationElementValue(String elementName, AnnotationMirror am) { + private static AnnotationValue getAnnotationElementValue( + String elementName, AnnotationMirror am) { for (Map.Entry entry : am.getElementValues().entrySet()) { if (entry.getKey().getSimpleName().contentEquals(elementName)) { @@ -623,7 +669,8 @@ private static String getAnnotationElementValueAsString(String elementName, Anno return value != null ? value.getValue().toString() : ""; } - private static boolean getAnnotationElementValueAsBoolean(String elementName, AnnotationMirror am, boolean defaultValue) { + private static boolean getAnnotationElementValueAsBoolean( + String elementName, AnnotationMirror am, boolean defaultValue) { AnnotationValue value = getAnnotationElementValue(elementName, am); return value != null ? Boolean.parseBoolean(value.toString()) : defaultValue; }