
Add effect handler that forces sample sites to sample the same value [feature request] #3374

Closed
BenZickel opened this issue Jun 13, 2024 · 1 comment


@BenZickel
Contributor

BenZickel commented Jun 13, 2024

Description

When searching for models that best describe observations, we often want to take a complex model
and simplify it by making several sample sites sample the same value.

This is similar to conditioning a sample site on a value (see pyro.poutine.condition), except that the shared value is itself sampled rather than fixed in advance.

As far as I can tell, there is no effect handler in Pyro that can do this.

Proposed Implementation

Create an effect handler, pyro.poutine.equalize, that works as described below:

import pyro
import torch

from pyro.infer.autoguide import AutoNormal

def per_category_model(category):
    shift = pyro.param(f'{category}_shift', torch.randn(1))
    mean = pyro.sample(f'{category}_mean', pyro.distributions.Normal(0, 1))
    std = pyro.sample(f'{category}_std', pyro.distributions.LogNormal(0, 1))
    with pyro.plate(f'{category}_num_samples', 5):
        return pyro.sample(f'{category}_values', pyro.distributions.Normal(mean + shift, std))

def model(categories):
    return {category: per_category_model(category) for category in categories}

categories = ['dogs', 'cats']

def print_trace(trace):
    for name, msg in trace.nodes.items():
        if msg['type'] in ('sample', 'param'):
            print(f"{msg['type']} {name} = {msg['value']}")

# Run the model
pyro.set_rng_seed(20240613)
pyro.clear_param_store()
trace = pyro.poutine.trace(model).get_trace(categories)
print('')
print('Original model site values')
print('--------------------------')
print_trace(trace)

# Suggested effect handler which forces
# the sample sites 'dogs_std' and 'cats_std'
# to have the same values
equal_std_model = pyro.poutine.equalize(model, ['dogs_std', 'cats_std'])
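# The effect handler can also work on parameters; here the sites are given
# as a regular expression that matches both 'dogs_shift' and 'cats_shift'
equal_std_param_model = pyro.poutine.equalize(equal_std_model, '.+_shift', 'param')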

# Run the updated model
pyro.set_rng_seed(20240613)
pyro.clear_param_store()
trace = pyro.poutine.trace(equal_std_param_model).get_trace(categories)
print('')
print('Updated model site values')
print('-------------------------')
print_trace(trace)

guide = AutoNormal(equal_std_param_model)
guide(categories)
print('')
print('Guide sites')
print('-----------')
print(guide().keys())
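The proposed calls pass the sites either as a list of exact names (['dogs_std', 'cats_std']) or as a regular expression ('.+_shift'), together with the site type to act on. Below is a minimal sketch of one possible selection rule, assuming each pattern must match the whole site name (re.fullmatch) and the type must match exactly. The helper select_sites is hypothetical, and the site list simply reproduces the pyro.param and pyro.sample statements that per_category_model executes, in order.

```python
import re


def select_sites(sites_in_order, patterns, site_type="sample"):
    """Hypothetical helper: names of the sites an equalize-style handler
    would tie together, in execution order.

    sites_in_order: list of (name, type) pairs, e.g. read off a trace.
    patterns: a single regex string, or a list of regexes / exact names.
    """
    if isinstance(patterns, str):
        patterns = [patterns]
    compiled = [re.compile(p) for p in patterns]
    return [
        name
        for name, kind in sites_in_order
        if kind == site_type and any(c.fullmatch(name) for c in compiled)
    ]


# The param/sample statements executed by per_category_model, in order.
example_sites = []
for category in ["dogs", "cats"]:
    example_sites += [
        (f"{category}_shift", "param"),
        (f"{category}_mean", "sample"),
        (f"{category}_std", "sample"),
        (f"{category}_values", "sample"),
    ]

std_group = select_sites(example_sites, ["dogs_std", "cats_std"])
shift_group = select_sites(example_sites, ".+_shift", "param")
print(std_group)    # ['dogs_std', 'cats_std']
print(shift_group)  # ['dogs_shift', 'cats_shift']
# Under this rule the first name in each group supplies the shared value;
# the remaining names would only copy it.
print(std_group[0], shift_group[0])  # dogs_std dogs_shift
```

Requiring a full match keeps a short pattern such as 'std' from silently catching unrelated sites, and filtering on the type keeps the param call from touching sample sites that happen to share a name pattern.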
BenZickel changed the title from "Added effect handler that forces sample sites to sample the same value [feature request]" to "Add effect handler that forces sample sites to sample the same value [feature request]" on Jun 13, 2024
@fritzo
Member

fritzo commented Jul 10, 2024

Resolved by #3375

fritzo closed this as completed on Jul 10, 2024