Description
When searching for models that best describe observations, we often want to take a complex model
and simplify it by making several sample sites sample the same value.
This is similar to conditioning a sample site on a value (see `pyro.poutine.condition`), except that here the shared value is itself sampled.
As far as I can tell, there is no effect handler in Pyro that does this.
Proposed Implementation
Create an effect handler `pyro.poutine.equalize` that works as described below:
```python
import pyro
import torch
from pyro.infer.autoguide import AutoNormal


def per_category_model(category):
    shift = pyro.param(f'{category}_shift', torch.randn(1))
    mean = pyro.sample(f'{category}_mean', pyro.distributions.Normal(0, 1))
    std = pyro.sample(f'{category}_std', pyro.distributions.LogNormal(0, 1))
    with pyro.plate(f'{category}_num_samples', 5):
        return pyro.sample(f'{category}_values', pyro.distributions.Normal(mean + shift, std))


def model(categories):
    return {category: per_category_model(category) for category in categories}


categories = ['dogs', 'cats']


def print_trace(trace):
    # Print the value of every sample and param site recorded in the trace
    for name, msg in trace.nodes.items():
        if msg['type'] in ('sample', 'param'):
            print(f"{msg['type']} {name} = {msg['value']}")


# Run the model
pyro.set_rng_seed(20240613)
pyro.clear_param_store()
trace = pyro.poutine.trace(model).get_trace(categories)
print('')
print('Original model site values')
print('--------------------------')
print_trace(trace)

# Suggested effect handler which forces the sample sites
# 'dogs_std' and 'cats_std' to have the same value
equal_std_model = pyro.poutine.equalize(model, ['dogs_std', 'cats_std'])

# The effect handler can also work on parameters (site names given as a regex)
equal_std_param_model = pyro.poutine.equalize(equal_std_model, '.+_shift', 'param')

# Run the updated model
pyro.set_rng_seed(20240613)
pyro.clear_param_store()
trace = pyro.poutine.trace(equal_std_param_model).get_trace(categories)
print('')
print('Updated model site values')
print('-------------------------')
print_trace(trace)

guide = AutoNormal(equal_std_param_model)
guide(categories)
print('')
print('Guide sites')
print('-----------')
print(guide().keys())
```
BenZickel changed the title from "Added effect handler that forces sample sites to sample the same value [feature request]" to "Add effect handler that forces sample sites to sample the same value [feature request]" on Jun 13, 2024.