
numpyro and blackjax samplers producing different results #144

Closed · theorashid opened this issue Oct 29, 2021 · 14 comments
Labels: bug (Something isn't working)

@theorashid commented Oct 29, 2021

Bug Description

I followed the use-with-numpyro notebook to get a model that works with numpyro's sampler running on blackjax. The model runs (quickly), but the values it produces are well off. I suspect this is due to a poor choice of step size and mass matrix.

Steps/Code to Reproduce

The code requires external data, so it is best to clone the repo if the problem isn't immediately solvable. The working numpyro code is here and the attempt using blackjax is here. The numpyro model and the blackjax code are also below.

import numpy as np
import jax
import jax.numpy as jnp
from jax import random
from functools import partial

import numpyro
import numpyro.distributions as dist
from numpyro.infer.util import initialize_model

# blackjax 0.2.x module layout, as used in this report
import blackjax.nuts as nuts
import blackjax.stan_warmup as stan_warmup


def model(home_id, away_id, score1_obs=None, score2_obs=None):
    # priors
    alpha = numpyro.sample("alpha", dist.Normal(0.0, 1.0))
    sd_att = numpyro.sample(
        "sd_att",
        dist.FoldedDistribution(dist.StudentT(3.0, 0.0, 2.5)),
    )
    sd_def = numpyro.sample(
        "sd_def",
        dist.FoldedDistribution(dist.StudentT(3.0, 0.0, 2.5)),
    )

    home = numpyro.sample("home", dist.Normal(0.0, 1.0))  # home advantage

    nt = len(np.unique(home_id))

    # team-specific model parameters
    with numpyro.plate("plate_teams", nt):
        attack = numpyro.sample("attack", dist.Normal(0, sd_att))
        defend = numpyro.sample("defend", dist.Normal(0, sd_def))

    # likelihood
    theta1 = jnp.exp(alpha + home + attack[home_id] - defend[away_id])
    theta2 = jnp.exp(alpha + attack[away_id] - defend[home_id])

    with numpyro.plate("data", len(home_id)):
        numpyro.sample("s1", dist.Poisson(theta1), obs=score1_obs)
        numpyro.sample("s2", dist.Poisson(theta2), obs=score2_obs)


rng_key = random.PRNGKey(0)

# translate the model into a log-probability function
init_params, potential_fn_gen, *_ = initialize_model(
    rng_key,
    model,
    model_args=(
        train["Home_id"].values,
        train["Away_id"].values,
        train["score1"].values,
        train["score2"].values,
    ),
    dynamic_args=True,
)

logprob = lambda position: -potential_fn_gen(
    train["Home_id"].values,
    train["Away_id"].values,
    train["score1"].values,
    train["score2"].values,
)(position)

initial_position = init_params.z
initial_state = nuts.new_state(initial_position, logprob)

# run the window adaptation (warmup)
kernel_factory = lambda step_size, inverse_mass_matrix: nuts.kernel(
    logprob, step_size, inverse_mass_matrix
)

last_state, (step_size, inverse_mass_matrix), _ = stan_warmup.run(
    rng_key, kernel_factory, initial_state, 1000
)


@partial(jax.jit, static_argnums=(1, 3))
def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

    return states, infos


# Build the kernel using the step size and inverse mass matrix returned from the window adaptation
kernel = kernel_factory(step_size, inverse_mass_matrix)

# Sample from the posterior distribution
states, infos = inference_loop(rng_key, kernel, last_state, 100_000)

Expected Results

As an example, the "home" parameter should be around 0.2-0.3. The "sd_att" and "sd_def" parameters should be constrained by the model to be positive (via FoldedDistribution()).

Actual Results

The parameter values are well off:
[screenshot of posterior summaries showing the off parameter values]

Versions

BlackJAX 0.2.1
numpyro 0.7.2
Python 3.8.0 | packaged by conda-forge | (default, Nov 22 2019, 19:11:19)
[Clang 9.0.0 (tags/RELEASE_900/final)]
Jax 0.2.17
Jaxlib 0.1.67

@junpenglao (Member) commented Oct 30, 2021

Thanks for reporting! This is likely related to #116, so assigning to @rlouf.

@rlouf added the bug label on Dec 22, 2021
@jeremiecoullon (Contributor) commented:

Hello!
I'm just wondering if this is still an issue, or has this been fixed by recent releases?

@rlouf (Member) commented Jun 9, 2022

@theorashid I'd be interested to know if this works with the latest version!

@theorashid (Author) commented Jun 10, 2022

Hey both, I left this stale because the pymc.sampling_jax.sample_blackjax_nuts function is less fiddly than using numpyro's initialize_model.

But I rewrote my model for the new blackjax API with numpyro earlier today and tested it out. There's good news and bad news:

Good news: the model is converging nicely
Bad news: the standard deviation parameters are negative. I'm not sure how this is happening, because it's the same model as the working numpyro version, which uses the FoldedDistribution to keep things positive. I tried replacing the priors with HalfNormals and the issue persists.

This is a weird one. Any ideas?

Code is here.
[screenshot of posterior summaries showing negative standard deviation parameters]

@rlouf (Member) commented Jun 10, 2022

Yay, great news! The values are negative because numpyro transforms the variables so that they take values between -infty and +infty before sampling (NUTS works much better with values of this kind). The logpdf that numpyro returns works with these "unconstrained" variables, so this is what blackjax returns. To get values of the untransformed variables you thus need to apply the inverse of the transformation that numpyro used, which I guess here is the absolute value.

Does that make sense?
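
(A minimal sketch, not from the original thread: numpyro's biject_to returns the forward map from the unconstrained reals onto a distribution's support, so it can undo the transformation generically. It assumes states.position holds the unconstrained draws from the inference loop above; note that for a positive support the registered bijection is an exp transform.)

import numpyro.distributions as dist
from numpyro.distributions.transforms import biject_to

# biject_to(support) is the bijection numpyro registers for a given support;
# applying it maps an unconstrained (real-valued) draw back onto the support
transform = biject_to(dist.FoldedDistribution(dist.StudentT(3.0, 0.0, 2.5)).support)
sd_att_constrained = transform(states.position["sd_att"])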

@rlouf (Member) commented Jun 10, 2022

Also, did you run the code with the latest version on master or the last release on PyPI?

@theorashid (Author) commented:

Okay great, I was wondering how numpyro was sampling a normal with negative sigma. Glad to hear it all makes sense.

I ran this using the latest PyPI release, 0.7.0:

blackjax 0.7.0 pypi_0 pypi
numpyro 0.9.0 pyhd8ed1ab_0 conda-forge

@rlouf (Member) commented Jun 10, 2022

Great, thank you for rerunning the example!

@rlouf closed this as completed on Jun 10, 2022
@Madhav-Kanda commented Jul 3, 2023

Is there a general method to know which inverse transformation to apply to recover the constrained parameters when using numpyro with blackjax? In the case of sigma it could be either the absolute value or the exponential. How do you decide which one it is, @rlouf?

@MDCHAMP commented Sep 3, 2023

I am also seeing poor performance when using numpyro.initialize_model, and looking at the samples I think it might also be related to some transformation. Was a solution to this found in the end?

@Madhav-Kanda commented:

No, I wasn't able to find a solution to it :(

@junpenglao (Member) commented:

For numpyro, a default transformation is associated with each bounded distribution at init time; e.g., for LogNormal the default transformation is Exp:
https://github.com/pyro-ppl/numpyro/blob/ca96eca8e8e1531e71ba559ef7a8ad3b4b68cbc2/numpyro/distributions/continuous.py#L1152-L1154
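
(A quick way to check this, sketched here rather than taken from the thread: ask biject_to which transform it registers for a distribution's support.)

import numpyro.distributions as dist
from numpyro.distributions.transforms import biject_to

# both positive supports resolve to an exp transform
print(type(biject_to(dist.LogNormal(0.0, 1.0).support)).__name__)  # ExpTransform
print(type(biject_to(dist.HalfNormal(1.0).support)).__name__)      # ExpTransform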

@junpenglao (Member) commented:

There should be a way to get all the transformations being applied using numpyro's effect handlers; maybe @fehiepsi knows? A sketch of one possible approach is below.
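
(A sketch along those lines, with a hypothetical helper name: run the model once under numpyro's trace and seed handlers, then look up the registered bijection for every latent sample site.)

from jax import random
from numpyro import handlers
from numpyro.distributions.transforms import biject_to

def site_transforms(model, rng_key, *model_args):
    # record every sample site by running the model once under trace/seed
    exec_trace = handlers.trace(handlers.seed(model, rng_key)).get_trace(*model_args)
    # keep the bijection for each latent (non-observed) site
    return {
        name: biject_to(site["fn"].support)
        for name, site in exec_trace.items()
        if site["type"] == "sample" and not site["is_observed"]
    }

# usage with the model from this thread:
# transforms = site_transforms(model, random.PRNGKey(0), home_id, away_id)
# transforms["sd_att"](unconstrained_draw) maps back onto the constrained space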

@MDCHAMP commented Sep 4, 2023

Hi all, I managed to get what I think is a solution! In the end I made use of the util function:

init_params_unconstrained, potential_fn_gen, post_proc_fun, model_trace = numpyro.infer.util.initialize_model(
    rng_key=...,
    model=...,
    model_args=...,
    dynamic_args=...,
)

Then you can get a function that undoes the transformations:

unmapper = post_proc_fun(None)

samples_constrained = unmapper(samples_unconstrained)

Hope this helps!
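
(One follow-up note, an assumption rather than something confirmed in the thread: when the unconstrained samples are stacked along a leading sample axis, as with the jax.lax.scan loop earlier in this issue, the post-processing function can be mapped over that axis.)

import jax

unmapper = post_proc_fun(None)
# apply the inverse transformations to each draw along the sample axis
samples_constrained = jax.vmap(unmapper)(samples_unconstrained)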
