
Encapsulate the sampling steps using OpFromGraph #76

Merged: 3 commits, Nov 29, 2022

Conversation

@rlouf (Member) commented Nov 12, 2022

In this PR I explore the possibility of encapsulating the sampling steps in their own Ops using OpFromGraph. This makes it possible to define types for the different kinds of samplers, and may allow easier manipulation of the graphs downstream, as discussed in #71.

  • Gibbs sampling steps. The Op is created in the sampler finders.
  • Added a SamplingStep type that subclasses OpFromGraph and holds the sampled RVs and the sampler parameters (see the sketch after this list).
  • NUTS
  • Can I add information that makes it possible to identify the parameters to adapt? If so, how should they be adapted? (parameter types?)
  • FFBS (not currently used in construct_sampler)
  • Can this be used to change the scan order of the different updates in the sampler?
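For illustration, a minimal sketch of the SamplingStep idea follows. This is not the PR's actual implementation: the constructor signature and attribute names are assumptions; only inline=True is taken from the dprint output further down.

from aesara.compile.builders import OpFromGraph


class SamplingStep(OpFromGraph):
    """An `Op` that encapsulates a single sampling step."""

    def __init__(self, inputs, outputs, sampled_rvs, parameters=(), **kwargs):
        # The random variables whose posterior values this step samples
        self.sampled_rvs = tuple(sampled_rvs)
        # The tunable parameters of this sampler (e.g. a step size)
        self.parameters = tuple(parameters)
        super().__init__(inputs, outputs, inline=True, **kwargs)

Each concrete kernel (e.g. a Gibbs kernel) would then be an instance or subclass of this type, which is what makes the different kinds of samplers identifiable in the graph.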

@rlouf changed the title from "Encapsulate the sampling steps in an OpFromGraph" to "Encapsulate the sampling steps using OpFromGraph" on Nov 12, 2022
@rlouf self-assigned this on Nov 12, 2022
@rlouf added the enhancement, refactoring, and sampler steps labels on Nov 12, 2022
@rlouf force-pushed the kernels-as-opfromgraphs branch 2 times, most recently from fb04626 to a5b0940 on November 12, 2022
codecov bot commented Nov 12, 2022

Codecov Report

Base: 97.41% // Head: 98.28% // Increases project coverage by +0.86% 🎉

Coverage data is based on head (950a836) compared to base (5c95102).
Patch coverage: 100.00% of modified lines in pull request are covered.

Additional details and impacted files
@@            Coverage Diff             @@
##             main      #76      +/-   ##
==========================================
+ Coverage   97.41%   98.28%   +0.86%     
==========================================
  Files           9       10       +1     
  Lines         619      698      +79     
  Branches       58       62       +4     
==========================================
+ Hits          603      686      +83     
  Misses          5        5              
+ Partials       11        7       -4     
Impacted Files Coverage Δ
aemcmc/basic.py 100.00% <100.00%> (ø)
aemcmc/gibbs.py 95.43% <100.00%> (+3.45%) ⬆️
aemcmc/nuts.py 98.50% <100.00%> (+0.50%) ⬆️
aemcmc/types.py 100.00% <100.00%> (ø)


☔ View full report at Codecov.

@rlouf force-pushed the kernels-as-opfromgraphs branch 6 times, most recently from 7a2107e to f8b96e7 on November 14, 2022
Review comment on aemcmc/gibbs.py (outdated, resolved)
@rlouf force-pushed the kernels-as-opfromgraphs branch 9 times, most recently from ff86900 to e33be86 on November 22, 2022
@rlouf (Member, Author) commented Nov 28, 2022

I have encapsulated the Gibbs steps and the NUTS step in OpFromGraph. The following code:

import aesara
import aesara.tensor as at
from aemcmc.basic import construct_sampler


srng = at.random.RandomStream(0)

X = at.matrix("X")

tau_rv = srng.halfcauchy(0, 1, name="tau")
lmbda_rv = srng.halfcauchy(0, 1, size=X.shape[1], name="lambda")
beta_rv = srng.normal(0, lmbda_rv * tau_rv, size=X.shape[1], name="beta")

a = at.scalar("a")
b = at.scalar("b")
h_rv = srng.gamma(a, b, name="h")

eta = X @ beta_rv
p = at.sigmoid(-eta)
Y_rv = srng.nbinom(h_rv, p, name="Y")

y_vv = Y_rv.clone()
y_vv.name = "y"

sample_vars = [tau_rv, lmbda_rv, beta_rv, h_rv]
sampler, initial_values = construct_sampler({Y_rv: y_vv}, srng)
h_post_step = sampler.sample_steps[h_rv]
aesara.dprint(h_post_step)

Returns:

DispersionGibbsKernel{inline=True}.0 [id A] 'h_posterior'
 |h [id B]
 |Elemwise{sigmoid,no_inplace} [id C]
 | |dot [id D]
 |   |Elemwise{neg,no_inplace} [id E]
 |   | |X [id F]
 |   |NBRegressionGibbsKernel{inline=True} [id G] 'beta_posterior'
 |     |beta [id H]
 |     |Elemwise{mul,no_inplace} [id I]
 |     | |HorseshoeGibbsKernel{inline=True}.0 [id J] 'lambda_posterior'
 |     | | |beta [id H]
 |     | | |Elemwise{pow,no_inplace} [id K]
 |     | | | |lambda [id L]
 |     | | | |InplaceDimShuffle{x} [id M]
 |     | | |   |TensorConstant{2} [id N]
 |     | | |Elemwise{pow,no_inplace} [id O]
 |     | | | |tau [id P]
 |     | | | |TensorConstant{2} [id Q]
 |     | | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B5A82E0>) [id R]
 |     | | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E4D1200>) [id S]
 |     | | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E549040>) [id T]
 |     | | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E85AF20>) [id U]
 |     | | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B5EA120>) [id V]
 |     | | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B5EA660>) [id W]
 |     | | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B6A19E0>) [id X]
 |     | |InplaceDimShuffle{x} [id Y]
 |     |   |HorseshoeGibbsKernel{inline=True}.1 [id J] 'tau_posterior'
 |     |X [id F]
 |     |h [id B]
 |     |y [id Z]
 |     |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E549040>) [id T]
 |     |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E85AF20>) [id U]
 |     |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B5C59E0>) [id BA]
 |     |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B5C5660>) [id BB]
 |     |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E4F3660>) [id BC]
 |     |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E4F3040>) [id BD]
 |     |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E4D1200>) [id S]
 |     |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B6A5E40>) [id BE]
 |a [id BF]
 |Elemwise{true_div,no_inplace} [id BG]
 | |TensorConstant{1.0} [id BH]
 | |b [id BI]
 |y [id Z]
 |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E4F3660>) [id BC]
 |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E4F3040>) [id BD]
 |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E4D1200>) [id S]
 |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E549040>) [id T]
 |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640E85AF20>) [id U]
 |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B5EAF20>) [id BJ]
 |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B6983C0>) [id BK]
 |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B5C59E0>) [id BA]
 |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B5C5660>) [id BB]
 |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F640B6A5E40>) [id BE]

You have probably also noticed that the signature of construct_sampler is different: sample_steps, updates, and parameters are now encapsulated in a dataclass:

sampler, initial_values = construct_sampler({Y_rv: y_vv}, srng)

print(sampler.sample_steps)
# {tau: tau_posterior, lambda: lambda_posterior, beta: beta_posterior, h: h_posterior}

print(sampler.samplers)
# [DispersionGibbsKernel, NBRegressionGibbsKernel, HorseshoeGibbsKernel]

print(sampler.parameters)
# {}
# A map between the `Apply` nodes and the parameters

A little refactoring work, especially around the names, is still necessary, but the interface feels good enough for now (note: if I added a type for the NUTS parameters I wouldn't need the sampler.parameters map).
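To show how the pieces fit together, here is a hedged usage sketch that compiles one full sweep over the sampling steps, reusing the variables from the snippet above. It assumes the Sampler object also exposes the RNG state updates as sampler.updates; that attribute name is a guess, not confirmed by this PR.

import aesara

# `initial_values` maps each RV to the input variable holding its current value
to_sample_rvs = [tau_rv, lmbda_rv, beta_rv, h_rv]
inputs = [a, b, X, y_vv] + [initial_values[rv] for rv in to_sample_rvs]
outputs = [sampler.sample_steps[rv] for rv in to_sample_rvs]
# `sampler.updates` is an assumed attribute holding the RNG state updates
sample_step_fn = aesara.function(inputs, outputs, updates=sampler.updates)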

@brandonwillard (Member) commented Nov 28, 2022

> I have encapsulated the Gibbs steps and the NUTS step in OpFromGraph.

That's awesome! We can definitely do something about all those shared RNG variables, too, especially via OpFromGraph's ability to handle updates (i.e. aesara-devs/aesara#1316). (N.B. aesara-devs/aesara#1306 was a precursor PR from a branch in which I implemented aesara-devs/aesara#1316; I'll push that shortly.)

@rlouf (Member, Author) commented Nov 29, 2022

I streamlined the interfaces as much as possible for now; we'll add more information to these objects as the need arises. The Sampler type holds all the information that construct_sampler returned before, with two differences (see the sketch at the end of this comment):

  • parameters is now a map from SamplingSteps to their parameters;
  • stages is a map from SamplingSteps to the (original) RVs whose values they update.

Can't wait to use this to construct the adaptation automatically.
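Putting that description together, here is a minimal sketch of what the Sampler dataclass could look like. The field types are assumptions inferred from the comments above, not the PR's exact definition; the keys of parameters and stages are the sampling-step Ops (OpFromGraph subclasses).

from dataclasses import dataclass
from typing import Dict, List, Tuple

from aesara.compile.builders import OpFromGraph
from aesara.graph.basic import Variable


@dataclass
class Sampler:
    # Maps each model RV to the graph that samples its posterior value
    sample_steps: Dict[Variable, Variable]
    # Maps each sampling step (an `OpFromGraph` subclass) to its parameters
    parameters: Dict[OpFromGraph, Tuple[Variable, ...]]
    # Maps each sampling step to the (original) RVs whose values it updates
    stages: Dict[OpFromGraph, List[Variable]]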

@brandonwillard (Member) left a comment

This looks great!
