-
-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Encapsulate the sampling steps using OpFromGraph
#76
Encapsulate the sampling steps using OpFromGraph
#76
Conversation
OpFromGraph
OpFromGraph
fb04626
to
a5b0940
Compare
Codecov ReportBase: 97.41% // Head: 98.28% // Increases project coverage by
Additional details and impacted files@@ Coverage Diff @@
## main #76 +/- ##
==========================================
+ Coverage 97.41% 98.28% +0.86%
==========================================
Files 9 10 +1
Lines 619 698 +79
Branches 58 62 +4
==========================================
+ Hits 603 686 +83
Misses 5 5
+ Partials 11 7 -4
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. ☔ View full report at Codecov. |
7a2107e
to
f8b96e7
Compare
ff86900
to
e33be86
Compare
I have encapsulated the Gibbs steps and the NUTS step in import aesara
import aesara.tensor as at
srng = at.random.RandomStream(0)
X = at.matrix("X")
tau_rv = srng.halfcauchy(0, 1, name="tau")
lmbda_rv = srng.halfcauchy(0, 1, size=X.shape[1], name="lambda")
beta_rv = srng.normal(0, lmbda_rv * tau_rv, size=X.shape[1], name="beta")
a = at.scalar("a")
b = at.scalar("b")
h_rv = srng.gamma(a, b, name="h")
eta = X @ beta_rv
p = at.sigmoid(-eta)
Y_rv = srng.nbinom(h_rv, p, name="Y")
y_vv = Y_rv.clone()
y_vv.name = "y"
sample_vars = [tau_rv, lmbda_rv, beta_rv, h_rv]
sampler, initial_values = construct_sampler({Y_rv: y_vv}, srng)
h_post_step = sampler.sample_steps[h_rv]
aesara.dprint(h_post_step) Returns: DispersionGibbsKernel{inline=True}.0 [id A] 'h_posterior'
|h [id B]
|Elemwise{sigmoid,no_inplace} [id C]
| |dot [id D]
| |Elemwise{neg,no_inplace} [id E]
| | |X [id F]
| |NBRegressionGibbsKernel{inline=True} [id G] 'beta_posterior'
| |beta [id H]
| |Elemwise{mul,no_inplace} [id I]
| | |HorseshoeGibbsKernel{inline=True}.0 [id J] 'lambda_posterior'
| | | |beta [id H]
| | | |Elemwise{pow,no_inplace} [id K]
| | | | |lambda [id L]
| | | | |InplaceDimShuffle{x} [id M]
| | | | |TensorConstant{2} [id N]
| | | |Elemwise{pow,no_inplace} [id O]
| | | | |tau [id P]
| | | | |TensorConstant{2} [id Q]
| | | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B5A82E0>) [id R]
| | | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E4D1200>) [id S]
| | | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E549040>) [id T]
| | | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E85AF20>) [id U]
| | | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B5EA120>) [id V]
| | | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B5EA660>) [id W]
| | | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B6A19E0>) [id X]
| | |InplaceDimShuffle{x} [id Y]
| | |HorseshoeGibbsKernel{inline=True}.1 [id J] 'tau_posterior'
| |X [id F]
| |h [id B]
| |y [id Z]
| |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E549040>) [id T]
| |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E85AF20>) [id U]
| |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B5C59E0>) [id BA]
| |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B5C5660>) [id BB]
| |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E4F3660>) [id BC]
| |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E4F3040>) [id BD]
| |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E4D1200>) [id S]
| |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B6A5E40>) [id BE]
|a [id BF]
|Elemwise{true_div,no_inplace} [id BG]
| |TensorConstant{1.0} [id BH]
| |b [id BI]
|y [id Z]
|RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E4F3660>) [id BC]
|RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E4F3040>) [id BD]
|RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E4D1200>) [id S]
|RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E549040>) [id T]
|RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E85AF20>) [id U]
|RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B5EAF20>) [id BJ]
|RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B6983C0>) [id BK]
|RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B5C59E0>) [id BA]
|RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B5C5660>) [id BB]
|RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B6A5E40>) [id BE] You have also probably noticed that the signature of sampler, initial_values = construct_sampler({Y_rv: y_vv}, srng)
print(sampler.sample_steps)
# {tau: tau_posterior, lambda: lambda_posterior, beta: beta_posterior, h: h_posterior}
print(sampler.samplers)
# [DispersionGibbsKernel, NBRegressionGibbsKernel, HorseshoeGibbsKernel]
print(sampler.parameters)
# {}
# A map between the `Apply` nodes and the parameters A little refactoring work, especially around the names, is necessary, but the interface feels good enough for now (note: If I added a type for the NUTS parameters I wouldn't need the |
That's awesome! We can definitely do something about all those shared RNG variables, too, and especially |
613836d
to
950a836
Compare
I streamlined the interfaces as much as possible for now, we'll add more information to these objects as the need arises. The
Can't wait to use this to construct the adaptation automatically. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks great!
In this PR I explore the possibility to encapsulate the sampling steps in their own
Op
s usingOpFromGraph
. This allows to define types for the different kinds of samplers, and may allow easier manipulation of the graphs downstream as discussed in #71.Op
is created in sampler findersSamplingStep
type that subclassesOpFromGraph
, holds the sampled RVs and the sampler parametersFFBS(not currently used inconstruct_sampler
)