SHOTerm not working in numpyro model #78

Open
tagordon opened this issue Mar 21, 2023 · 3 comments

Comments

@tagordon
Collaborator

Hi @dfm,

I'm trying to use numpyro to sample a GP with the SHO kernel as follows:

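import jax

# Enable 64-bit floats; jax.config.update replaces the older
# `from jax.config import config` import, which newer JAX releases no longer provide.
jax.config.update("jax_enable_x64", True)
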
import jax.numpy as jnp
from celerite2.jax import GaussianProcess, terms

import numpyro.distributions as dist
from numpyro import sample
from numpyro.infer import MCMC, NUTS

prior_sigma = 1.0

def numpyro_model(x, yerr, y=None):

    mean = sample("mean", dist.Normal(0.0, prior_sigma))
    logjitter = sample("logjitter", dist.Normal(-26, 3 * prior_sigma))

    logsigma = sample("logsigma", dist.Normal(-11, 3 * prior_sigma))
    rho = sample("rho", dist.Normal(1.0, 3 * prior_sigma))
    tau = sample("tau", dist.Normal(0.1, prior_sigma))
        
    term = terms.SHOTerm(sigma=jnp.exp(logsigma), rho=rho, tau=tau)
    gp = GaussianProcess(term, mean=mean)
    gp.compute(x, diag = yerr**2 + jnp.exp(logjitter), check_sorted=False)

    sample("obs", gp.numpyro_dist(), obs=y)
    
nuts_kernel = NUTS(numpyro_model, dense_mass=True, target_accept_prob=0.9)
mcmc = MCMC(
    nuts_kernel,
    num_warmup=1000,
    num_samples=1000,
    num_chains=2,
    progress_bar=True,
)
rng_key = jax.random.PRNGKey(34923)
yerr = 1e-8
mcmc.run(rng_key, x, yerr, y=y)

and I'm getting an error with a long traceback that ends:

/usr/local/lib/python3.8/site-packages/celerite2/jax/terms.py in SHOTerm(*args, **kwargs)
    397     over = OverdampedSHOTerm(*args, **kwargs)
    398     under = UnderdampedSHOTerm(*args, **kwargs)
--> 399     if over.Q < 0.5:
    400         return over
    401     return under

    [... skipping hidden 2 frame]

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=1/0)>
Full traceback
---------------------------------------------------------------------------
ConcretizationTypeError                   Traceback (most recent call last)
<ipython-input-139-60f93f4ec4d5> in <module>
    1 yerr = 1e-8
----> 2 mcmc.run(rng_key, x, yerr, y=y)
    3 samples = mcmc.get_samples()

/usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
  596         else:
  597             if self.chain_method == "sequential":
--> 598                 states, last_state = _laxmap(partial_map_fn, map_args)
  599             elif self.chain_method == "parallel":
  600                 states, last_state = pmap(partial_map_fn)(map_args)

/usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _laxmap(f, xs)
  158     for i in range(n):
  159         x = jit(_get_value_from_index)(xs, i)
--> 160         ys.append(f(x))
  161 
  162     return tree_map(lambda *args: jnp.stack(args), *ys)

/usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
  379         rng_key, init_state, init_params = init
  380         if init_state is None:
--> 381             init_state = self.sampler.init(
  382                 rng_key,
  383                 self.num_warmup,

/usr/local/lib/python3.8/site-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
  704                 vmap(random.split)(rng_key), 0, 1
  705             )
--> 706         init_params = self._init_state(
  707             rng_key_init_model, model_args, model_kwargs, init_params
  708         )

/usr/local/lib/python3.8/site-packages/numpyro/infer/hmc.py in _init_state(self, rng_key, model_args, model_kwargs, init_params)
  650     def _init_state(self, rng_key, model_args, model_kwargs, init_params):
  651         if self._model is not None:
--> 652             init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
  653                 rng_key,
  654                 self._model,

/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
  654         init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
  655     prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 656     (init_params, pe, grad), is_valid = find_valid_initial_params(
  657         rng_key,
  658         substitute(

/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
  395     # Handle possible vectorization
  396     if rng_key.ndim == 1:
--> 397         (init_params, pe, z_grad), is_valid = _find_valid_params(
  398             rng_key, exit_early=True
  399         )

/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in _find_valid_params(rng_key, exit_early)
  388         # XXX: this requires compiling the model, so for multi-chain, we trace the model 2-times
  389         # even if the init_state is a valid result
--> 390         _, _, (init_params, pe, z_grad), is_valid = while_loop(
  391             cond_fn, body_fn, init_state
  392         )

/usr/local/lib/python3.8/site-packages/numpyro/util.py in while_loop(cond_fun, body_fun, init_val)
  129         return val
  130     else:
--> 131         return lax.while_loop(cond_fun, body_fun, init_val)
  132 
  133 

  [... skipping hidden 9 frame]

/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in body_fn(state)
  365                 z_grad = jacfwd(potential_fn)(params)
  366             else:
--> 367                 pe, z_grad = value_and_grad(potential_fn)(params)
  368             z_grad_flat = ravel_pytree(z_grad)[0]
  369             is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))

  [... skipping hidden 8 frame]

/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in potential_energy(model, model_args, model_kwargs, params, enum)
  247     )
  248     # no param is needed for log_density computation because we already substitute
--> 249     log_joint, model_trace = log_density_(
  250         substituted_model, model_args, model_kwargs, {}
  251     )

/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in log_density(model, model_args, model_kwargs, params)
   60     """
   61     model = substitute(model, data=params)
---> 62     model_trace = trace(model).get_trace(*model_args, **model_kwargs)
   63     log_joint = jnp.zeros(())
   64     for site in model_trace.values():

/usr/local/lib/python3.8/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
  169         :return: `OrderedDict` containing the execution trace.
  170         """
--> 171         self(*args, **kwargs)
  172         return self.trace
  173 

/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
  103             return self
  104         with self:
--> 105             return self.fn(*args, **kwargs)
  106 
  107 

/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
  103             return self
  104         with self:
--> 105             return self.fn(*args, **kwargs)
  106 
  107 

/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
  103             return self
  104         with self:
--> 105             return self.fn(*args, **kwargs)
  106 
  107 

/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
  103             return self
  104         with self:
--> 105             return self.fn(*args, **kwargs)
  106 
  107 

/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
  103             return self
  104         with self:
--> 105             return self.fn(*args, **kwargs)
  106 
  107 

<ipython-input-137-39666cc8f7df> in numpyro_model(x, yerr, y)
   10     tau = sample("tau", dist.Normal(0.1, prior_sigma))
   11 
---> 12     term = jTerms.SHOTerm(sigma=jnp.exp(logsigma), rho=rho, tau=tau)
   13     gp = jGP(term, mean=mean)
   14     gp.compute(x, diag = yerr**2 + jnp.exp(logjitter), check_sorted=False)

/usr/local/lib/python3.8/site-packages/celerite2/jax/terms.py in SHOTerm(*args, **kwargs)
  397     over = OverdampedSHOTerm(*args, **kwargs)
  398     under = UnderdampedSHOTerm(*args, **kwargs)
--> 399     if over.Q < 0.5:
  400         return over
  401     return under

  [... skipping hidden 2 frame]

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=1/0)>
The problem arose with the `bool` function. 
The error occurred while tracing the function body_fn at /usr/local/lib/python3.8/site-packages/numpyro/infer/util.py:315 for while_loop. This concrete value was not available in Python because it depends on the value of the argument state[1].

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

It runs just fine if I replace terms.SHOTerm with terms.UnderdampedSHOTerm and constrain the hyperparameters to the underdamped regime. Any idea what's going on here?

Thanks!

@bmorris3

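This probably falls under the "Traced value used in control flow" section of the docs for ConcretizationTypeError. Inside numpyro's while_loop, over.Q is an abstract tracer rather than a concrete number, so JAX can't evaluate the Python `if over.Q < 0.5` in SHOTerm.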
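To make the failure mode concrete, here is a minimal sketch that reproduces the same class of error without celerite2 or numpyro. The function names (`pick_branch_python`, `pick_branch_where`) and the arithmetic in each branch are hypothetical stand-ins for the over/underdamped choice. They are not celerite2's implementation.

```python
import jax
import jax.numpy as jnp


def pick_branch_python(q):
    # A Python `if` calls bool() on the comparison. Under jit (or inside
    # lax.while_loop) q is an abstract tracer, so bool() cannot be evaluated.
    if q < 0.5:
        return q * 2.0
    return q * 3.0


def pick_branch_where(q):
    # jnp.where computes both branches and selects between them, so the
    # choice stays inside the traced computation.
    return jnp.where(q < 0.5, q * 2.0, q * 3.0)


# With a concrete Python float the plain `if` works as usual.
eager_value = pick_branch_python(0.25)

# Under jit the same function fails with a ConcretizationTypeError
# (newer JAX versions raise a subclass of it).
try:
    jax.jit(pick_branch_python)(0.25)
    raised = False
except jax.errors.ConcretizationTypeError:
    raised = True

# The jnp.where version traces, and can be differentiated as well.
low = jax.jit(pick_branch_where)(0.25)
high = jax.jit(pick_branch_where)(1.0)
slope = jax.grad(pick_branch_where)(0.25)
```

`jnp.where` evaluates both branches, so it only fits when both are cheap and finite. `jax.lax.cond` is the usual alternative when only one branch should run.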

@dfm
Member

dfm commented Mar 22, 2023

@tagordon, @bmorris3 is right, but there are some options. In the released version of celerite2 that you're using, the suggested approach is to use an OverdampedSHOTerm or UnderdampedSHOTerm directly and restrict Q to the valid range yourself (Q < 0.5 for the overdamped term, Q >= 0.5 for the underdamped one). The GitHub version of celerite2 no longer has this issue, so you might be better off installing from GitHub directly for now.
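One way to "restrict to a valid Q yourself" is to reparameterize so that Q cannot leave the underdamped regime. This is a rough sketch, not code from celerite2. It assumes the relations w0 = 2*pi/rho and tau = 2*Q/w0 between the (rho, tau) and (w0, Q) parameterizations, which give Q = pi*tau/rho. The helper names and the `log_dq` variable are hypothetical.

```python
import jax
import jax.numpy as jnp


def q_from_rho_tau(rho, tau):
    # Assumed relation: w0 = 2*pi/rho and tau = 2*Q/w0, hence Q = pi*tau/rho.
    return jnp.pi * tau / rho


def underdamped_tau(rho, log_dq):
    # Hypothetical reparameterization: Q = 0.5 + exp(log_dq) stays above 0.5
    # for any moderate real log_dq, so no Python `if` on a traced value is needed.
    q = 0.5 + jnp.exp(log_dq)
    return q * rho / jnp.pi


@jax.jit
def implied_q(rho, log_dq):
    # Round trip: build tau from the unconstrained variable, then recover Q.
    tau = underdamped_tau(rho, log_dq)
    return q_from_rho_tau(rho, tau)


q_small = implied_q(1.0, -3.0)
q_large = implied_q(2.0, 2.0)
```

In a numpyro model, `log_dq` would be sampled in place of `tau`, and the resulting `tau` passed to `terms.UnderdampedSHOTerm`, the variant reported to work above. Whether this prior on Q suits a given dataset is a separate modelling choice.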

@tagordon
Collaborator Author

Thanks @dfm and @bmorris3! I'll give the GitHub version a try.
