
Various op-related changes #36


Merged · 7 commits · Apr 3, 2020

Conversation

@karllessard (Collaborator) commented Mar 28, 2020

Here is a summary of all changes applied in this PR:

  • Rename PrimitiveOp class to RawOp to match Python name
  • Rollback tf.val to tf.constant to match Python name
  • Move PrimitiveOp.op() to the Op interface level, which converts an op to a single Operation
  • Change training optimizers and initializers to accept and return instances of Op instead of Operand
  • Extend Op from Operand
  • Add new methods in Session for executing graph initializers in one call (runInitializers())
  • Add Init operator (see comment)

CC: @dhruvrajan, @Craigacp

@karllessard (Collaborator, Author)

@Craigacp, I just made a final update that changes the behaviour of variable initialization a bit, leveraging Ops instead of accessing the Graph object directly (which might not always be in scope when it's time to add initializers).

So now, initializer registration looks like this (with the addition of a new Init operator, see Init.java):

try (Graph g = new Graph()) {
  Ops tf = Ops.create(g);

  Variable<TInt32> x = tf.variable(tf.constant(10));
  Variable<TInt32> y = tf.variable(tf.constant(20));
  Add<TInt32> z = tf.math.add(x, y);
  tf.init();

  try (Session s = new Session(g)) {
    s.runInit();
    try (Tensor<TInt32> t = s.runner().fetch(z).run().get(0).expect(TInt32.DTYPE)) {
      assertEquals(30, t.data().getInt());
    }
  }
}

This also allows a user to give a different name to the initializer op using standard scope semantics:

....
tf.withName("myInit").init();

try (Session s = new Session(g)) {
   s.runInit("myInit");
   ...
}

So Graph.variableInitializers() has simply been removed and everything goes through Ops now. Please let me know your thoughts on this, thanks.

scope = scope.withName(name).withControlDependencies(initializers);
return NoOp.create(scope);

public List<Op> initializers() {
  return Collections.unmodifiableList(initializers);
}
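To make the control-dependency idea in the snippet above concrete, here is a plain-Java analogue (the class and method names are made up for this sketch, not the real TensorFlow implementation): a single init action that triggers every registered initializer, mimicking a NoOp whose control dependencies are the initializer list.

```java
// Illustrative plain-Java analogue of the Init pattern: running "init"
// runs everything it depends on, the way control dependencies do.
// InitSketch and its methods are hypothetical names for this sketch only.
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class InitSketch {
  private final List<Runnable> initializers = new ArrayList<>();

  // Analogous to the graph collecting registered initializers.
  public void register(Runnable initializer) {
    initializers.add(initializer);
  }

  public List<Runnable> initializers() {
    return Collections.unmodifiableList(initializers);
  }

  // Analogous to NoOp.create(scope.withControlDependencies(initializers)):
  // the returned action does nothing itself but forces all initializers to run.
  public Runnable init() {
    return () -> initializers.forEach(Runnable::run);
  }
}
```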
Contributor:

Nice!

* }
* }</pre>
*/
public void runInit(Init initOp) {
Contributor:

Thanks for these changes Karl, I think these will be super helpful! Quick thought:

Since the new Init class is a RawOp, do we gain anything by adding the runInit method, as opposed to running the RawOp in the normal fashion:

session.runner().addTarget(initOp).run()

We could even add additional run methods to Session that accept single Op/Operand/etc. arguments, and have them automatically added as targets, to achieve syntax like

session.run(initOp)
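A rough sketch of how that convenience overload could delegate to the runner; the Op, Runner, and Session types below are minimal stand-ins for the real org.tensorflow classes, purely to show the delegation pattern:

```java
// Hypothetical sketch of a Session.run(Op) convenience overload.
// These types are stand-ins, not the real org.tensorflow classes.
import java.util.ArrayList;
import java.util.List;

public class RunOverloadSketch {

  // Stand-in for org.tensorflow.op.Op.
  interface Op {}

  // Stand-in runner: records targets instead of executing a graph.
  static class Runner {
    private final List<Op> targets = new ArrayList<>();

    Runner addTarget(Op op) {
      targets.add(op);
      return this;
    }

    List<Op> run() {
      return targets; // a real runner would execute the graph here
    }
  }

  static class Session {
    Runner runner() {
      return new Runner();
    }

    // Proposed convenience: session.run(op) is shorthand for
    // session.runner().addTarget(op).run().
    List<Op> run(Op op) {
      return runner().addTarget(op).run();
    }
  }
}
```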

TensorFlow Python (pre 2.0) has a standard syntax for the variable initializers (for both global and local variables)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
 
    ...

Maybe we could modify the Java API to match? It may not yet be on our roadmap to distinguish between global (shared across processes) variables and local (per-process) variables, but we could perhaps try for something like:

try (Graph graph = new Graph()) {
  Ops tf = Ops.create(graph);

  ...

  try (Session session = new Session(graph)) {
    session.addTarget(tf.variablesInitializer()).run(); // or session.run(tf.variablesInitializer)

    ...

  }
}

tf.variablesInitializer() can call Graph.initializers() that you added below. To make it even simpler, we could automatically add the Assign op to a graph when tf.variable is called with an initial value, so users don't need to keep track of adding each initialization to the initializers list.

The global vs. local variable distinction may be something we want to add later on though...

What do you think?

Collaborator (Author):

Thanks @dhruvrajan ,

We could even add additional run methods to Session that accept single Op/Operand/etc. arguments

Yes, I thought of having a session.run(Op) as well. My idea was that by typing it we kind of "guide" the user to use it only for running initializers, but I agree we can relax this a bit.

tf.variablesInitializer() can call Graph.initializers() that you added below.

The tf.variableInitializers() sounds pretty much like the tf.init(), no? But the problem that arises, which @Craigacp mentioned to me before, is that when you load a graph from disk, your graph won't have any registered initializers, so you can't build that op at runtime; you need to find it in the graph.

That is why the tf.init() is always called before running the session. It's kind of annoying, I agree, to always have to remember to call this method before freezing the graph, and I don't know how we can make it more obvious... I can also revert to what Adam did before, where we could retrieve the init op from the Graph instead of Ops (i.e. we would do something like session.run(graph.variableInitializers())). At least that simple run method would be an improvement.

To make it even simpler, we could automatically add the Assign op to a graph when tf.variable is called with an initial value

This is already supported! Check here

Contributor:

This is already supported! Check here

Ah thanks for pointing that out, awesome!

when you load a graph from disk, your graph won't have any initializers so you can't build that op at runtime, you need to find it in the graph.

Hmm this is interesting! If you load a graph from disk, why won't the graph have any initializers? (It may have init ops in the graph, I guess, but they wouldn't be mapped to the initializers list we maintain in Java?). In this case, could we just create an Init Op that does nothing?

That is why the tf.init() is always called before running the session.

To clarify, must the Init op be created before a session is created, or just run before the first session.run? If it is the latter, we can have:

try (Session session = new Session(graph)) {
  session.run(tf.init()); // maybe rename to variablesInitializer() for consistency?
}

Brainstorming other ways to keep the syntax consistent with Python, I had a couple ideas, eager to know your thoughts @karllessard:

  1. If it really needs to go before Session.create, add the Init creation (tf.init()) call as the first step of Session.create() or a tf.session() factory method.

I can't think of a case where we would not want this Op to be added to the graph; this could also eliminate the need for the user to call tf.init()?

  2. As in TensorFlow Python, keep track of individual variable initial_values and initializer ops as fields within the Variable class.

Thus instead of keeping track of a list of initializer ops on the graph, we only keep track of variables. We can extract the initializers by mapping over the variables. (See: global_variables, global_variables_initializer)

Then we get syntax and functionality super close to standard TF (I think, quite desirable)

try (Session session = tf.session()) { // automatically runs tf.init()
    session.run(tf.variablesInitializer());
}
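The second idea, keeping each initializer on its Variable and deriving the list on demand, could be sketched like this (Variable and Op are stand-ins for the real org.tensorflow types, and the method names are assumptions for illustration):

```java
// Hypothetical sketch: deriving the initializer list from tracked variables,
// as Python's global_variables_initializer does. Not the real API.
import java.util.List;
import java.util.stream.Collectors;

public class VariableInitializersSketch {

  // Stand-in for org.tensorflow.op.Op.
  interface Op {}

  // Each variable holds its own initializer op, as in Python's tf.Variable.
  static class Variable {
    private final Op initializer;

    Variable(Op initializer) {
      this.initializer = initializer;
    }

    Op initializer() {
      return initializer;
    }
  }

  // Analogue of global_variables_initializer(): instead of a graph-level
  // initializer list, extract the initializers by mapping over the variables.
  static List<Op> initializers(List<Variable> variables) {
    return variables.stream()
        .map(Variable::initializer)
        .collect(Collectors.toList());
  }
}
```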

Thanks for the detailed discussion! 😃

@Craigacp (Collaborator) commented Mar 31, 2020:

The loaded graph does have initializers but the Java code doesn't know about them. We could consider crawling the graph and giving initializers a privileged name to rebuild it, but we'd need to sync this up with everything in the TF ecosystem that emits GraphDef, as otherwise it could have unintended consequences when loading graphs created by other languages' TF implementations.

The Python global_variables_initializer needs to be created in the same way that tf.init does at the moment; what's in this PR is pretty close to how TF 1 in Python does it.

Collaborator (Author):

My concern with naming it variableInitializers instead of just init is that I think there might be other ops that must be called prior to running a session besides variable assigns.

For instance, in this old example I had to run the createSummaryFileWriter as well.

So maybe that name is not "right" in Python either, but they kept it for backward compatibility. Or they have more than one op run before a session starts.

Now for calling tf.init() implicitly at session creation or explicitly... I don't have the answer right now, to be honest; it is a tough call that requires more thinking. If @Craigacp is right that the current behaviour mimics what is done in Python, then I would be comfortable merging it as is and maybe finding other ways to improve it later?

Contributor:

Yup I agree; good idea to merge this, and update things reactively in the future.

For now, then, the two main things I see that will be different between the current implementation and the Python implementation are:

  1. In Java, we must call tf.init() before creating a session; in Python, the equivalent function can be called after a session is created.

  2. In Java, we explicitly keep track of initializers within the Graph object. In Python, initializers are held by Variable objects, and retrieved when needed.

Let's just keep these differences in mind and see if anything related comes up in the future!

Collaborator (Author):

Thinking about it, for point 1, I'm going to test again; maybe it's OK to add an op to a graph after a session is created, so you can simply call session.run(tf.init()) if the graph is built in the same process as the session runner.

Collaborator (Author):

Yep, looks like this works too!

*/
public void runInit(String initOpName) {
  Operation operation = graph.operation(initOpName);
  if (operation == null) {
Collaborator:

Can we check the type of the operation too? If the user creates something that doesn't use the init mechanism but does have the name, it might be nice to have some kind of warning.

Collaborator:

That said, it might be an interop issue with graphs created in other TF languages if it warns every time you run a graph created in Python.

Collaborator (Author):

@dhruvrajan's suggestion is to relax the constraint a bit here and rename runInit to run, which would accept any type of operation. I'm ok with this too, what do you think?

Collaborator:

Fine by me. It's essentially just addTarget(String).run() then though?

Collaborator (Author):

A shortcut to session.runner().addTarget(name).run(), yes.

Collaborator:

Given this exists, should we add a convenience method for session.runner().fetch(name).run()? Also the exception still mentions an initializer operation, but this isn't required anymore.

Collaborator (Author):

I guess we could. But fetching is trickier because you need to make sure to release the returned tensor, a behaviour I would like us to review as a whole at some point, wdyt?

If you are ok with this, then I'll update the exception message (thanks for catching this) and merge this PR now.

Collaborator:

I agree it's more complex, but it seems weird to me to have one without the other. Either way we can deal with it later.

dhruvrajan previously approved these changes Apr 2, 2020

@dhruvrajan (Contributor) left a comment:

Thanks Karl, looking forward to using this!

@karllessard (Collaborator, Author):

Ok, code has been updated!


try (Session s = new Session(g)) {
-  s.runInit(init);
+  s.run(tf.init());
Contributor:

This looks awesome!

Craigacp previously approved these changes Apr 3, 2020
@karllessard karllessard merged commit bd51145 into tensorflow:master Apr 3, 2020
@karllessard karllessard deleted the op-changes branch April 3, 2020 20:00
@karllessard (Collaborator, Author):

Thanks guys, it is merged.

karllessard added a commit to tensorflow/java-models that referenced this pull request Apr 7, 2020