Add rewrite for sum of normal RVs #239
Conversation
Force-pushed from 590044e to c8ceb6d
Codecov Report

```
@@            Coverage Diff             @@
##             main     #239      +/-   ##
==========================================
- Coverage   95.76%   95.76%   -0.01%
==========================================
  Files          12       13       +1
  Lines        2006     2029      +23
  Branches      243      246       +3
==========================================
+ Hits         1921     1943      +22
  Misses         46       46
- Partials       39       40       +1
```
This looks good!
I've added some comments regarding the handling of RNG objects, but that's a more general concern that we're trying to address in all of our rewriting work, and not a statement about this specific approach/implementation.
aeppl/math_stat.py (outdated diff excerpt):

```python
mu_y, sigma_y = Y_rv.owner.inputs[-2:]

new_node = normal.make_node(
    X_rv.owner.inputs[0],  # temporary rng?
```
This is a tough one we've been trying to deal with more generally. The real complications come from any "updates" that are associated with the RNG objects.

More specifically, RNG objects are usually `SharedVariable`s, and those can have `SharedVariable.default_update` attributes that hold onto other `Variable`s (i.e. graphs representing the new value of the `SharedVariable` after each call to a compiled Aesara function with this update). In the case of RNG objects created by `RandomStream`s, those default updates are the RNG objects output after drawing a sample from a `RandomVariable` node. In other words, the updates mechanism replaces a shared RNG object with a copy of the RNG object after a sample has been drawn using it, so the Aesara updates mechanism is emulating an in-place update of the RNG (e.g. just as `rng.normal()` automatically updates the internal state of `rng` in NumPy).
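As a quick check of that mechanism (a sketch; it uses the same `RandomStream` API that appears later in this thread):

```python
import aesara.tensor as at

srng = at.random.RandomStream(123)
x = srng.normal(0, 1)

rng = x.owner.inputs[0]
# The shared RNG's default update is the *RNG output* of the node
# (outputs[0]), not the sample itself (x is outputs[1])
assert rng.default_update is x.owner.outputs[0]
```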
When performing replacements like this, it's possible that `rng = X_rv.owner.inputs[0]` will have a `rng.default_update` containing the original `X_rv` graph, and, if someone attempted to `aesara.compile` the resulting `new_node`, Aesara would pick up that old graph from the re-used RNG's `default_update` attribute and add it to the compiled results. We definitely don't want to have to sample some unrelated graphs just to update the RNG objects, especially when those update graphs aren't consistent with the underlying sampling process.
Anyway, I'm pointing this out because it's a general design issue and usability complication of which I'm trying to make more people aware—mostly so we can fix it/improve the usability.
In this exact case, we can probably just clone the RNG `SharedVariable`s and add our own `default_update`s to those (e.g. similar to how `RandomStream.gen` does). Also, the graphs produced here are only supposed to be used as an intermediate representation for obtaining log-probabilities, so, as long as we don't expect people to actually compile and sample these graphs, the default updates shouldn't matter, and we can probably just clone the RNGs and remove their default updates altogether.
Hmm, there's a lot of info here, but I think I have the general picture. Thanks for the explanation.

When you say clone the RNG `SharedVariable`, do you mean:

```python
new_rng = X_rv.owner.inputs[0].clone()
# something more needs to be done to `new_rng`...
new_node = normal.make_node(new_rng, *other_inputs)
new_node.inputs[0].default_update = new_node
```

Something is certainly off because `x_rv.owner.inputs[0].default_update == x_rv` yields `False`.
I'm also confused because it feels like there are two RNGs here, one from `X_rv` and one from `Y_rv`... Or should I create a new RNG object akin to what's being done in `RandomStream.gen`?
> When you say clone the RNG `SharedVariable`, do you mean:

Yeah, with one important difference, though:

```python
new_rng = X_rv.owner.inputs[0].clone()
# something more needs to be done to `new_rng`...
new_node = normal.make_node(new_rng, *other_inputs)
new_rng.default_update = new_node.outputs[0]
```

In other words, it's the `new_rng` that needs to be updated with the values of the RNGs output by `new_node`.
> I'm also confused because it feels like there are two RNGs here, one from `X_rv` and one from `Y_rv`

Which RNG we use is really just a choice. The only thing we need to consider is the user-level seeding, and, as long as we choose an RNG from one of the existing `RandomVariable`s, we should maintain some consistency with the seeding. If we use `RandomStream.gen`, then we can't generate a connection with the user's seeding unless we have their `RandomStream` instance.
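Putting the two pieces together, a rough sketch of the replacement (`other_inputs` is a placeholder for the new node's size/dtype/parameter inputs, not a name from the PR):

```python
# Re-use X's RNG so the user's seeding still has an effect, but clone it
# so the new node doesn't inherit X's `default_update` graph
rng = X_rv.owner.inputs[0]
new_rng = rng.clone()  # assumption: the clone starts without a default_update

new_node = normal.make_node(new_rng, *other_inputs)

# The cloned RNG is updated with the RNG output of the new node...
new_rng.default_update = new_node.outputs[0]
# ...and the actual draw is the node's second output
new_rv = new_node.outputs[1]
```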
Also, because this whole `RandomVariable` + `SharedVariable.default_update` situation is such a mess, I've created this: aesara-devs/aesara#1478.
Force-pushed from c8ceb6d to 43bc724
Force-pushed from 43bc724 to cb2bf14
This is an elaboration of #239 (comment).

That's a very important question! Those inputs are usually `SharedVariable`s wrapping RNG objects, and those shared RNGs can carry `default_update` graphs. Here's an illustration:

```python
import aesara
import aesara.tensor as at
srng = at.random.RandomStream(23092)
X_rv = srng.normal(0, 1, name="X")
aesara.dprint(X_rv, print_default_updates=True)
# normal_rv{0, (0, 0), floatX, False}.1 [id A] 'X'
# |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F44C55C1120>) [id B] <- [id A]
# |TensorConstant{[]} [id C]
# |TensorConstant{11} [id D]
# |TensorConstant{0} [id E]
# |TensorConstant{1} [id F]
#
# Default updates:
#
# normal_rv{0, (0, 0), floatX, False}.0 [id A]
# |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F44C55C1120>) [id B] <- [id A]
# |TensorConstant{[]} [id C]
# |TensorConstant{11} [id D]
# |TensorConstant{0} [id E]
# |TensorConstant{1} [id F]
```

Notice the default update: after each draw, the shared RNG [id B] is replaced by the RNG output (`.0`) of the same `normal_rv` node [id A]. When this graph is compiled, that update is applied on every call to the resulting function. In pure Python, this whole situation is similar to the following:

```python
from copy import copy

import numpy as np
rng = np.random.default_rng(2309)
def draw_normal(rng):
new_rng = copy(rng)
res = new_rng.normal(0, 1)
return new_rng, res
# Replace the old `rng` with the returned RNG
rng, res = draw_normal(rng)
```

Now, problems arise when a shared RNG that already carries a `default_update` is re-used in a new graph. This is what happens when we re-use an RNG with default updates for another graph:

```python
# The RNG from `X`
rng_X = X_rv.owner.inputs[0]
# Manually create a `RandomVariable` with a specific RNG
Y_rv = at.random.gamma(0.5, 0.5, name="Y", rng=rng_X)
# Compile the function
Y_rv_fn = aesara.function([], Y_rv)
# View the compiled graph
aesara.dprint(Y_rv_fn)
# gamma_rv{0, (0, 0), floatX, True}.1 [id A] 'Y' 1
# |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F44C55C1120>) [id B]
# |TensorConstant{[]} [id C]
# |TensorConstant{11} [id D]
# |TensorConstant{0.5} [id E]
# |TensorConstant{2.0} [id F]
# normal_rv{0, (0, 0), floatX, False}.0 [id G] 0
# |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F44C55C1120>) [id B]
# |TensorConstant{[]} [id C]
# |TensorConstant{11} [id D]
# |TensorConstant{0} [id H]
# |TensorConstant{1} [id I]
```

As you can see, the compiled graph is sampling both a gamma and a normal variable, but the normal sampling is only done in order to update the RNG shared variable. If we want to reuse existing RNG objects, we have to deal with the `default_update` graphs they already carry (e.g. by cloning the RNGs and giving the clones their own updates).

In AePPL, we generally don't sample the IR we produce, so this isn't an immediate problem, but it could easily creep into graphs somewhere down the line and cause real issues. As far as this rewrite is concerned, we're probably fine copying the `SharedVariable` and (re)specifying the `.default_update` graph.
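For completeness, a sketch of a safe way to re-use `rng_X` in the gamma example above (it assumes, as is worth verifying, that `clone()` does not carry over `default_update`):

```python
# Clone the RNG so the gamma node doesn't inherit X's update graph
new_rng = rng_X.clone()

Y_rv = at.random.gamma(0.5, 0.5, name="Y", rng=new_rng)

# Give the clone its own update: the RNG output of the new node
new_rng.default_update = Y_rv.owner.outputs[0]
```

Compiling `aesara.function([], Y_rv)` would then only sample the gamma variable.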
Force-pushed from 4f4a30b to 00f82b8
Thanks @brandonwillard for the thorough reply. It took me a while to get a clear picture of the problem you're addressing. @rlouf I incorporated your suggestions into the unit tests.

> As far as this rewrite is concerned, we're probably fine copying the `SharedVariable` and (re)specifying the `.default_update` graph.

The only thing that I have yet to fully address is the RNG of the newly created normal RV node. I gave some details as a comment in this code review.
Force-pushed from 00f82b8 to 5a136dd
@brandonwillard Is this PR close to done? Should I get started on the remaining rewrites in #238 and build on top of this PR?
> @brandonwillard Is this PR close to done? Should I get started on the remaining rewrites in #238 and build on top of this PR?
As far as the rewrite is concerned, it looks good. We can merge it as-is, but I don't think it will be used because of the rewrite ordering and existing transform-based approach.
We'll need to change the rewrite DB setup in order to allow rewrites like this (i.e. ones that provide more/better distribution information) to come first in place of the generic transforms approach. That concern is independent of this work, though.
OK, I'm about to push a change that adds support for subtraction and changes the DB to which the rewrite is registered. I'll merge after that.
Force-pushed from 5a136dd to 325afb8
This should be good to merge once the tests pass. Thanks again, @larryshamalama; this was a great addition!
Closes 1/3 of #238, for now. As I'm opening this PR, the rewrite works for scalar-valued normal RVs.
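For intuition, here's a small NumPy sketch of the identity this rewrite encodes (illustrative only; it isn't the PR's implementation):

```python
import numpy as np

# If X ~ Normal(mu_x, sigma_x) and Y ~ Normal(mu_y, sigma_y) are independent,
# then X + Y ~ Normal(mu_x + mu_y, sqrt(sigma_x**2 + sigma_y**2))
rng = np.random.default_rng(23092)
mu_x, sigma_x = 1.0, 2.0
mu_y, sigma_y = 0.5, 1.5

z = rng.normal(mu_x, sigma_x, 500_000) + rng.normal(mu_y, sigma_y, 500_000)
print(z.mean())  # ~= 1.5
print(z.std())   # ~= np.sqrt(sigma_x**2 + sigma_y**2) == 2.5
```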
I'm happy to receive comments and then address the list below. The main to-do is to extend to matrix-valued normal RVs.
Some questions:

- What should be passed as the `rng` argument in the `make_node` call? For now, I left it as the `rng` of the first RV input, but this is incorrect.
- The rewrite is registered with an `EquilibriumGraphRewriter`. Also not sure about this.
- The rewrite lives in a new file, `math_stat.py`. Would there be a better name for this file? Or should this be added to an existing file?

Happy to hear any comments!