Skip to content

Commit

Permalink
Initialization imprvements (tensorflow#178)
Browse files Browse the repository at this point in the history
* No-op on initAdd in eager mode

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* runInit() method in session

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* add doInitialization() to Runner

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* fix javadoc

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* assume only graph or eager environments

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* Remove doInit(), update javadocs

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* small fixes

Signed-off-by: Ryan Nett <rnett@calpoly.edu>
  • Loading branch information
rnett authored and JimClarke5 committed Jan 30, 2021
1 parent af1b49f commit 7732601
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -347,10 +347,10 @@ public final class Ops {

public final SignalOps signal;

public final QuantizationOps quantization;

public final TrainOps train;

public final QuantizationOps quantization;

private final Scope scope;

private Ops(Scope scope) {
Expand All @@ -372,8 +372,8 @@ private Ops(Scope scope) {
math = new MathOps(this);
audio = new AudioOps(this);
signal = new SignalOps(this);
quantization = new QuantizationOps(this);
train = new TrainOps(this);
quantization = new QuantizationOps(this);
}

/**
Expand Down Expand Up @@ -2755,11 +2755,10 @@ public Init init() {
*
* <p>Registered initializers are then grouped as a single unit of computation by adding
* and executing an {@link org.tensorflow.op.core.Init#create(Scope) init} operation from a graph
* session.
* session. This is a no-op if executed in an eager session.
*
* @param scope
* @param initializer
* @throws IllegalArgumentException if the execution environment in scope is not a graph
* @see org.tensorflow.op.core.Init#create(Scope) init
*/
public void initAdd(Op initializer) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,19 @@ public void run(Op op) {
runner().addTarget(op.op()).run();
}


/**
* Execute the graph's initializers.
*
* <p>This method is equivalent to {@code session.run(Ops.create(session.graph).init())}.
*
*/
public void runInit(){
Runner runner = runner();
graph.initializers().forEach(runner::addTarget);
runner.run();
}

/**
* Saves the actual state of the variables of this session's graph.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,21 +89,19 @@ public static Init create(Scope scope) {
*
* <p>Registered initializers are then grouped as a single unit of computation by adding
* and executing an {@link org.tensorflow.op.core.Init#create(Scope) init} operation from a graph
* session.
* session. This is a no-op if executed in an eager session.
*
* @param scope
* @param initializer
* @throws IllegalArgumentException if the execution environment in scope is not a graph
* @see org.tensorflow.op.core.Init#create(Scope) init
*/
@Endpoint(name = "initAdd")
public static void add(Scope scope, Op initializer) {
ExecutionEnvironment exEnv = scope.env();
if (!(exEnv instanceof Graph)) {
throw new IllegalArgumentException("initAdd is only supported on Graph sessions.");

if (exEnv.isGraph()) {
((Graph) exEnv).addInitializer(initializer);
}
Graph graph = (Graph) exEnv;
graph.addInitializer(initializer);
}

private Init(Operation operation) {
Expand Down

0 comments on commit 7732601

Please sign in to comment.