[WIP] tf.Variable like API for variables #179
Conversation
I don't think splitting out mutability is a good choice here. The variable can be mutated under the covers by the TF runtime (and will be, because that's what it's for), and the documentation here isn't clear that this is the case. Also, an immutable variable sounds a lot like a constant, and we've already got those, which these docs don't reference. I'm not clear what's gained by making the split, given that the only implementation is mutable and the immutable view has a method that returns the mutable version. Could you give me more of an idea why having an unmodifiable view as the default is worthwhile?
I think using the reference variables is fine and we should probably migrate, but there are a bunch of things in the framework which gloss over the existence of such variables, and we'll need to migrate those (when I ported the optimizers over from Python I ignored all the reference ops and used the regular ones, but presumably we'll need to switch to using the reference ops).
  }
  cachedRead = ReadVariableOp.create(scope, handle, tType);
}
return cachedRead;
This method has data races. In Python they don't need to worry about it, but we should at least consider what the possible behaviours are in Java.
Yeah, I should use a local variable there. Do you think using AtomicReference as well is worth it? I'm not sure how much of the rest is thread safe. If not, I'll have to remove the method.
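Something like this is what I had in mind, as a rough sketch only; it reuses handle and tType from the diff above, while the getReadOp name and the AtomicReference field are just illustrative:

import java.util.concurrent.atomic.AtomicReference;

// Sketch of a race-free cache for the read op. A plain local variable fixes the
// torn read; the AtomicReference makes the publish itself atomic as well.
private final AtomicReference<ReadVariableOp<T>> cachedRead = new AtomicReference<>();

private ReadVariableOp<T> getReadOp(Scope scope) {
  ReadVariableOp<T> read = cachedRead.get();
  if (read == null) {
    ReadVariableOp<T> created = ReadVariableOp.create(scope, handle, tType);
    // If another thread published first, drop our op and reuse the winner's,
    // so every caller ends up with the same cached read.
    read = cachedRead.compareAndSet(null, created) ? created : cachedRead.get();
  }
  return read;
}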
The immutable (more properly, read-only) version is somewhat similar to Kotlin's collections. It's always backed by a mutable version, and in most cases can be cast and mutated, but the mutation APIs are hidden. So it's not immutable, just a read-only view of the mutable variable, semantically. I do need to update the docs a bit to reflect that. The idea was that since assigning to a variable will generally be a user error, it should be harder to do and made more explicit.
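Roughly, the split looks like this (a simplified sketch; the method names match the PR, but the signatures here are illustrative):

// Read-only view: the default type users see.
public interface Variable<T extends TType> {
  // Read the current value.
  Operand<T> value();

  // Explicit escape hatch to the mutable view, analogous to casting a Kotlin
  // List to MutableList; assignment stays possible but must be asked for.
  MutableVariable<T> asMutableVariable();
}

// Mutable view: backs every Variable, but the mutation APIs live only here.
public interface MutableVariable<T extends TType> extends Variable<T> {
  void assign(Operand<T> value);
}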
Optimizers will definitely need a pass. I originally didn't do it as part of this PR since I wasn't going to use the resource variables, but now that I am, I probably should. I'd also like to make it easier to pass in the list of variables to update, rather than using the entire graph, for multi-model support (i.e. GANs). One other thing that came up: I can't use initialization as a control dependency, since it would be re-run each time. Is there a way to require that initialization has been run in the session? You will get an uninitialized error if you don't run it, but I'd like a clearer way (or ideally to run it automatically, like a control dependency, just only once per session).
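For reference, what this looks like today is running the initializers explicitly, once, before the training loop; a sketch in graph mode, where initOp, trainOp and numSteps are placeholders for whatever is being targeted:

try (Graph g = new Graph(); Session session = new Session(g)) {
  // Run the variable initializers exactly once per session...
  session.runner().addTarget(initOp).run();
  // ...then run training steps without re-running initialization.
  for (int step = 0; step < numSteps; step++) {
    session.runner().addTarget(trainOp).run();
  }
}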
Currently somewhat blocked by tensorflow/tensorflow#46114, although if/when we add gradient registry ops we could work around it.
Sure, but that's not how Java's collections work, and I think it might be preferable not to leak Kotlin-isms into the Java API used by all the JVM languages. Doesn't Python expose these operations to allow updates to epoch numbers, learning rates and similar? I think those would be common use cases we'd have in Java too.
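For instance (sketch only, assuming an assign method on the mutable view), this is the kind of update I'd expect users to write for learning-rate decay between epochs:

// Hypothetical helper: decay a learning-rate variable between epochs through the
// mutable view; assign here is assumed from the API sketched above.
static void decayLearningRate(Ops tf, MutableVariable<TFloat32> learningRate,
                              float initialRate, float decay, int epoch) {
  learningRate.assign(tf.constant(initialRate / (1f + decay * epoch)));
}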
The lack of this gradient (which presumably just passes straight through) means it's not possible to train any models which contain this variable, right?
Yeah, and it is just a passthrough.
Yeah, true, I'll refactor that. If/when proper logging is added, it would be good to have a warning when variables are mutated when they aren't "supposed" to be (i.e. in the forward pass).
Ok, well this needs to go on hold till they fix that upstream then.
How would we identify what's a forward pass vs a backward pass? Privileging the target names seems restrictive.
private final Shape shape;
private final DataType dataType;
private final boolean trainable;
Shouldn't this be mutable? In Keras people freeze and unfreeze layers all the time, so flipping this back and forth seems reasonable.
It's not mutable in Python, so I copied that (see here). It doesn't seem that hard to do, but we'd want something to ensure it's not changed while a gradient tape is active in eager mode, or have some special handling for that. It depends on what that API looks like, though.
I don't know how Keras handles it. Maybe with the StopGradient op? It seems like they have their own trainable and non-trainable weight management outside of this.
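i.e. something like this, as a sketch, where reads of a frozen variable are wrapped so the backward pass treats them as constants:

// Sketch: route reads of a "frozen" variable through stopGradient so gradients
// don't flow back into it, independently of the trainable flag.
static <T extends TType> Operand<T> frozenRead(Ops tf, Variable<T> variable) {
  return tf.stopGradient(variable.value());
}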
Ah, ok, so the Python docs make sense. The notes on the behaviour of trainable in SavedModel mean it's still counter-intuitive, as I'd like to store the epoch count in a checkpoint (it makes it easier to restart training), but oh well.
It would be nice to add the trainable variables to a hook in the graph the same way they do in Python. The registerVariable hook could check if it's trainable.
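Something like this, as a sketch; the variables and trainableVariables collections on the graph, and the isTrainable() accessor, are assumptions here:

// Hypothetical hook on Graph: register every variable, and additionally track the
// trainable ones so optimizers can grab that list directly.
void registerVariable(Variable<?> variable) {
  variables.add(variable);
  if (variable.isTrainable()) {
    trainableVariables.add(variable);
  }
}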
I could, yeah. I figured it was trivial to filter the variables list. I'd think any optimizer implementations would want to take a list of variables from whatever models they use anyway, to allow for multiple models in the same graph.
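e.g. the filtering is a one-liner (assuming an isTrainable() accessor for the trainable field above):

import java.util.List;
import java.util.stream.Collectors;

// Hypothetical helper: keep only the trainable variables before handing the list
// to an optimizer; isTrainable() is an assumed accessor for the trainable field.
static List<Variable<?>> trainableOnly(List<Variable<?>> variables) {
  return variables.stream()
      .filter(Variable::isTrainable)
      .collect(Collectors.toList());
}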
BTW, the Keras Model collects the "trainable" variables and passes them to the optimize function.
Yeah. I made a PR following the C++ gradient instructions, but it doesn't work. I'm going to follow up on the mailing list but haven't had time yet.
Yeah, I don't know how feasible that is. It would probably have to be a framework-level setting, which we shouldn't mix into core. For eager mode, it's easy enough to check if any gradient tapes are active, but I don't know of a way to do that in graph mode.
This will change fairly significantly due to #237 (and parts are currently broken), so don't merge it. However, with tensorflow/tensorflow#46115, the gradient works.
Fixes #170.
Adds a tf.Variable like class, using the new resource API (see here and here). It's compatible with Eager and Graph mode, and will work better with eager gradient tapes in the future.

A couple of points I'd like feedback on:
- The read-only default (the mutable version is obtained via asMutableVariable()).
- The value(), create(), etc. methods that don't take a Scope parameter and use the scope the variable is created with. While these are certainly much nicer to use, they do muddle the semantics a bit, and if gradient tapes are done via Scope it would probably cause issues (since the tapes need to be notified of variable usage, afaik). I originally did it this way because shape() depends on value(), which needs a scope, although it is possible to just return the variable's initial shape.
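For reference, rough intended usage looks something like this; it's a sketch only, and the exact factory and assignment signatures may change as the PR evolves:

try (Graph g = new Graph()) {
  Ops tf = Ops.create(g);

  // Create a variable from an initial value; the creation scope is remembered,
  // which is what enables the scope-less value()/shape() convenience methods.
  // Variable.create's signature here is hypothetical.
  Variable<TFloat32> weights =
      Variable.create(tf.scope(), tf.constant(new float[][] {{1f, 2f}, {3f, 4f}}));

  // Reads are the default, no Scope argument needed...
  Operand<TFloat32> read = weights.value();

  // ...and assignment has to go through the explicit mutable view.
  weights.asMutableVariable().assign(tf.constant(new float[][] {{0f, 0f}, {0f, 0f}}));
}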