[WIP] tf.Variable like API for variables #179


Closed
wants to merge 17 commits

Conversation

rnett
Contributor

@rnett rnett commented Dec 31, 2020

Fixes #170.

Adds a tf.Variable-like class using the new resource API (see here and here). It's compatible with eager and graph mode, and will work better with eager gradient tapes in the future.

A couple of points I'd like feedback on:

  1. Immutability by default. Most of the time variables are only updated by the optimizer, not by any code users will write, so I think this makes sense. It's also not hard to access the mutable version (asMutableVariable()).
  2. Auto-scoping operations to the creation scope. These are the value(), create(), etc. methods that don't take a Scope parameter and instead use the scope the variable was created with. While they are certainly much nicer to use, they do muddle the semantics a bit, and if gradient tapes are done via Scope it would probably cause issues (since the tapes need to be notified of variable usage, afaik). I originally did it this way because shape() depends on value(), which needs a scope, although it is possible to just return the variable's initial shape. (A rough sketch of both points follows below.)
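
Purely for illustration, here is a minimal sketch of the shape this split could take. These are not the PR's actual classes; only value() and asMutableVariable() mirror names mentioned above, and Scope is a stand-in rather than org.tensorflow.op.Scope.

```java
// Illustrative only: a toy version of the two design points above.
final class Scope {}  // stand-in, not org.tensorflow.op.Scope

interface Variable<T> {
  T value();            // auto-scoped: uses the scope the variable was created with
  T value(Scope scope); // explicitly scoped variant

  // Mutation is opt-in via an explicit escape hatch (point 1).
  MutableVariable<T> asMutableVariable();
}

interface MutableVariable<T> extends Variable<T> {
  void assign(T newValue);
}

// A single mutable implementation backs the read-only view.
final class VariableImpl<T> implements MutableVariable<T> {
  private final Scope creationScope;
  private T value;

  VariableImpl(Scope creationScope, T initialValue) {
    this.creationScope = creationScope;
    this.value = initialValue;
  }

  @Override public T value() { return value(creationScope); } // point 2: default to the creation scope
  @Override public T value(Scope scope) { return value; }     // a real impl would emit a read op in `scope`
  @Override public void assign(T newValue) { this.value = newValue; }
  @Override public MutableVariable<T> asMutableVariable() { return this; }
}
```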

Collaborator

@Craigacp Craigacp left a comment

I don't think splitting out mutability is a good choice here. The variable can be mutated under the covers by the TF runtime (and will be, because that's what it's for), and the documentation here isn't clear that this is the case. Also, an immutable variable sounds a lot like a constant, and we've already got those, which these docs don't reference. I'm not clear what's gained by making the split, given that the only implementation is mutable and the immutable view has a method that returns the mutable version. Could you give me more of an idea of why having an unmodifiable view as the default is worthwhile?

I think using the reference variables is fine and we should probably migrate, but there are a bunch of things in the framework that gloss over the existence of such variables, and we'll need to migrate those (when I ported over the optimizers from Python I ignored all the reference ops and used the regular ones, but presumably we'll need to switch to using the reference ops).

```java
}
cachedRead = ReadVariableOp.create(scope, handle, tType);
}
return cachedRead;
```
Collaborator

This method has data races. In Python they don't need to worry about it, but we should at least consider what the possible behaviours are in Java.

Contributor Author

Yeah, I should use a local variable there. Do you think using AtomicReference as well is worth it? I'm not sure how much of the rest is thread-safe.
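
For reference, a minimal sketch of the local-variable-plus-AtomicReference approach, with stand-in types; only the caching idiom corresponds to the snippet above, and ReadVariableOp.create would sit behind the supplier.

```java
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

// Sketch of a race-free lazy cache for the read op.
final class CachedReadSketch<T> {
  private final AtomicReference<T> cachedRead = new AtomicReference<>();
  private final Supplier<T> readOp; // e.g. () -> ReadVariableOp.create(scope, handle, tType)

  CachedReadSketch(Supplier<T> readOp) {
    this.readOp = readOp;
  }

  T get() {
    // Read into a local first so the null check and the return see the same value.
    T local = cachedRead.get();
    if (local == null) {
      T created = readOp.get();
      // Publish atomically; if another thread won the race, use its value so every
      // caller sees a single cached op.
      if (!cachedRead.compareAndSet(null, created)) {
        created = cachedRead.get();
      }
      local = created;
    }
    return local;
  }
}
```

One caveat: two racing threads can still each build a read op and discard one, which in graph mode would leave a stray node behind; synchronizing the creation instead would avoid that at the cost of a lock.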

@rnett
Contributor Author

rnett commented Dec 31, 2020

isValueInitialized(), which uses IsVariableInitialized, requires a ref tensor; is there any way to get one in the Java API yet?

If not, I'll have to remove the method.

@rnett
Contributor Author

rnett commented Dec 31, 2020

The immutable (more properly, read-only) version is somewhat similar to Kotlin's collections. It's always backed by a mutable version, and may be cast and mutated in most cases, but the mutation APIs are hidden. So it's not immutable, just a read-only view of the mutable variable, semantically. I do need to update the docs a bit to reflect that. The idea was that since a user assigning to a variable will generally be an error, it should be harder to do and more explicit.

@rnett
Contributor Author

rnett commented Dec 31, 2020

Optimizers will definitely need a pass. I originally didn't do it as part of this PR since I wasn't going to use the resource variables, but now that I am, I probably should. I'd also like to make it a bit easier to pass in the list of variables to update, rather than using the entire graph, for multi-model support (e.g. GANs).

One other thing that came up: I can't use initialization as a control dependency, since it would be re-run each time. Is there a way to require that initialization has been run in the session? You will get an uninitialized error if you don't run it, but I'd like a clearer way (or ideally to run it automatically, like a control dependency, but only once per session).
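
One possible shape for the "only once per session" behaviour, purely as a sketch: track per Session whether a (hypothetical) init target has already been run, and run it lazily before the first real call. Only the Session.runner()/addTarget/run calls are existing API; the class, the target name, and the bookkeeping are made up for the example.

```java
import java.util.Collections;
import java.util.Set;
import java.util.WeakHashMap;

import org.tensorflow.Session;

// Illustrative only: run a named init target at most once per Session.
final class OncePerSessionInit {
  // Weak keys so closed/garbage-collected sessions don't accumulate here.
  private static final Set<Session> initialized =
      Collections.newSetFromMap(new WeakHashMap<Session, Boolean>());

  static synchronized void ensureInitialized(Session session, String initTargetName) {
    if (initialized.add(session)) {
      // First time we've seen this session: run only the init target.
      session.runner().addTarget(initTargetName).run();
    }
  }
}
```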

@rnett rnett requested a review from Craigacp January 3, 2021 02:01
@rnett
Contributor Author

rnett commented Jan 3, 2021

Currently somewhat blocked by tensorflow/tensorflow#46114, although if/when we add gradient registry ops we could work around it.

@Craigacp
Collaborator

Craigacp commented Jan 3, 2021

The immutable (more properly, read-only) version is somewhat similar to Kotlin's collections. It's always backed by a mutable version, and may be cast and mutated in most cases, but the mutation APIs are hidden. So it's not immutable, just a read-only view of the mutable variable, semantically. I do need to update the docs a bit to reflect that. The idea was that since a user assigning to a variable will generally be an error, it should be harder to do and more explicit.

Sure, but that's not how Java's collections work, and I think it might be preferable not to leak Kotlin-isms into the Java API used by all the JVM languages. Doesn't Python expose these operations to allow updates to epoch numbers, learning rates, and similar? I think those would be common use cases we'd have in Java too.

Currently somewhat blocked by tensorflow/tensorflow#46114, although if/when we add gradient registry ops we could work around it.

The lack of this gradient (which presumably just passes straight through) means it's not possible to train any models which contain this variable right?

@rnett
Contributor Author

rnett commented Jan 3, 2021

The lack of this gradient (which presumably just passes straight through) means it's not possible to train any models which contain this variable right?

Yeah, and it is just a passthrough.

Doesn't Python expose these operations to allow updates to epoch numbers, learning rates, and similar? I think those would be common use cases we'd have in Java too.

Yeah, true, I'll refactor that. If/when proper logging is added, it would be good to have a warning when variables are mutated where they aren't "supposed" to be (i.e. in the forward pass).

@Craigacp
Collaborator

The lack of this gradient (which presumably just passes straight through) means it's not possible to train any models which contain this variable right?

Yeah, and it is just a passthrough.

Ok, well this needs to go on hold till they fix that upstream then.

Doesn't Python expose these operations to allow updates to epoch numbers, learning rates, and similar? I think those would be common use cases we'd have in Java too.

Yeah, true, I'll refactor that. If/when proper logging is added, it would be good to have a warning when variables are mutated where they aren't "supposed" to be (i.e. in the forward pass).

How would we identify what's a forward pass vs a backward pass? Privileging the target names seems restrictive.


```java
private final Shape shape;
private final DataType dataType;
private final boolean trainable;
```
Collaborator

Shouldn't this be mutable? In Keras people freeze and unfreeze layers all the time, so flipping this back and forth seems reasonable.

Contributor Author

@rnett rnett Jan 15, 2021

It's not mutable in Python, so I copied that (see here). It doesn't seem that hard to do, but we'd want something to ensure it's not changed while a gradient tape is active in eager mode, or have some special handling for that. It depends on what that API looks like, though.

I don't know how Keras handles it. Maybe with the StopGradient op? It seems like they have their own trainable and non-trainable weight management outside of this.

Collaborator

Ah ok, so the Python docs make sense. The notes on the behaviour of trainable in SavedModel mean it's still counter-intuitive, since I'd like to store the epoch count in a checkpoint (it makes restarting training easier), but oh well.

Collaborator

It would be nice to add the trainable variables to a hook in the graph the same way they do in Python. The registerVariable hook could check if it's trainable.

Contributor Author

I could, yeah. I figured it was trivial to filter the variables list. I would think any optimizer implementations would want to take a list of variables from whatever types of models they use anyway, to allow for multiple models in the same graph.
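
As a trivial illustration of that filtering (Var here is a stand-in, not the PR's variable class):

```java
import java.util.List;
import java.util.stream.Collectors;

// Stand-in type for this example only; the real variable class would expose
// something like isTrainable() alongside its other metadata.
final class Var {
  final String name;
  final boolean trainable;

  Var(String name, boolean trainable) {
    this.name = name;
    this.trainable = trainable;
  }
}

final class TrainableFilter {
  // An optimizer (or the model feeding it) can take an explicit variable list
  // and keep only the trainable subset, rather than scanning the whole graph.
  static List<Var> trainableSubset(List<Var> variables) {
    return variables.stream()
        .filter(v -> v.trainable)
        .collect(Collectors.toList());
  }
}
```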

Contributor

BTW, the Keras Model collects the "trainable" variables and passes them to the optimize function.

@rnett
Contributor Author

rnett commented Jan 15, 2021

Ok, well this needs to go on hold till they fix that upstream then.

Yeah. I made a PR following the C++ gradient instructions, but it doesn't work. I'm going to follow up on the mailing list but haven't had time yet.

How would we identify what's a forward pass vs a backward pass? Privileging the target names seems restrictive.

Yeah, I don't know how feasible that is. It would probably have to be a framework-level setting, which we shouldn't mix into core. For eager mode, it's easy enough to check whether any gradient tapes are active, but I don't know of a way to do that in graph mode.

rnett added 16 commits March 12, 2021 12:12
@rnett
Contributor Author

rnett commented Mar 12, 2021

This will change fairly significantly due to #237 (and parts are currently broken), so don't merge it. However, with tensorflow/tensorflow#46115, the gradient works.

@JimClarke5
Contributor

In Metrics and Model, several variables are updated outside of Optimizers.

@JimClarke5
Contributor

callbacks.LearningRateScheduler requires a way to update the learningRate in an Optimizer. The original thought was to use Placeholders for this, but if we could use a Variable we wouldn't need to do funny stuff with feeds and changing Tensor values.
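
For illustration, a variable-backed learning rate would let a scheduler callback simply assign a new value at the start of each epoch instead of feeding a placeholder. This is a hypothetical sketch; none of these names are the project's actual API.

```java
import java.util.function.IntToDoubleFunction;

// Stand-in for a mutable float variable (e.g. the optimizer's learning rate).
interface FloatVariable {
  float value();
  void assign(float newValue);
}

// Hypothetical scheduler: one assignment per epoch, no feeds or tensor swapping.
final class LearningRateSchedulerSketch {
  private final FloatVariable learningRate;
  private final IntToDoubleFunction schedule; // epoch -> learning rate

  LearningRateSchedulerSketch(FloatVariable learningRate, IntToDoubleFunction schedule) {
    this.learningRate = learningRate;
    this.schedule = schedule;
  }

  void onEpochBegin(int epoch) {
    learningRate.assign((float) schedule.applyAsDouble(epoch));
  }
}
```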

@rnett rnett closed this Sep 19, 2021
@rnett rnett deleted the rn_variable branch September 19, 2021 00:58