Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Error in model using CircularReparam when trying to use Predictive #1846

Closed
tomwallis opened this issue Aug 7, 2024 · 1 comment · Fixed by #1856
Closed

Error in model using CircularReparam when trying to use Predictive #1846

tomwallis opened this issue Aug 7, 2024 · 1 comment · Fixed by #1856

Comments

@tomwallis
Copy link
Contributor

Hello,

I'm trying to use the Predictive class with a model whose response variable is circular. I get a NotImplementedError, raised deep in the call stack, when I use Predictive with the (recommended) circular reparameterization, but not when I don't use it.

Is this error related to not being able to use reparameterization on the observed variable (likelihood)? Or something specific to circular reparameterization? Any help much appreciated!

Minimal example:

import numpyro
import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.handlers import reparam
from numpyro.infer.reparam import CircularReparam

x = jnp.linspace(-jnp.pi, jnp.pi)
# Simulate noisy circular responses: add Gaussian noise to x, then wrap the
# result back onto [-pi, pi]. The raw noisy values can fall outside that range,
# but the VonMises likelihood is defined on the circle. The wrap is the same
# formula CircularReparam applies internally: remainder(v + pi, 2*pi) - pi.
y = jnp.remainder(
    x + random.normal(random.key(234), shape=x.shape) + jnp.pi, 2 * jnp.pi
) - jnp.pi

def model(x, y=None):
    """Von Mises regression of circular responses ``y`` on predictor ``x``.

    The mean direction is ``b * x``, with a standard-normal prior on the slope
    ``b``. The concentration is ``1 / sigma**2``, with an Exponential(1) prior
    on ``sigma``. When ``y`` is None, the "obs" site has no observed value and
    is sampled instead.
    """
    slope = numpyro.sample("b", dist.Normal(0, 1))
    scale = numpyro.sample("sigma", dist.Exponential(1))
    likelihood = dist.VonMises(loc=slope * x, concentration=1 / scale**2)
    numpyro.sample("obs", likelihood, obs=y)


# NOTE: "obs" is the observed (likelihood) site of `model`, so CircularReparam
# is being applied to an observation here. When `y` is supplied, the reparam
# passes it through as the observed value of "obs_unwrapped". When `y` is
# omitted (as in the Predictive call below), the reparam has to draw
# "obs_unwrapped" from an ImproperUniform distribution, which has no sampler,
# so NotImplementedError is raised. Only latent circular sites should be
# reparameterized; using `model` directly avoids the error.
reparam_model = reparam(model, config={"obs": CircularReparam()})

# Start from this source of randomness. We will split keys for subsequent operations.
rng_key = random.key(4159)
rng_key, rng_key_ = random.split(rng_key)

# Run NUTS.
# (y is passed here, so "obs" is observed; the reported error comes from the
# Predictive call below, not from this step.)
kernel = NUTS(reparam_model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4)
mcmc.run(rng_key_, x=x, y=y)
mcmc.print_summary()
samples_1 = mcmc.get_samples()

# generate posterior predictions on fitted values
# y is not passed, so "obs" must be sampled; with the reparam applied to "obs"
# this call raises the NotImplementedError shown in the trace.
rng_key, rng_key_ = random.split(rng_key)
predictive = Predictive(reparam_model, samples_1)
predictions = predictive(rng_key_, x=x)["obs"]

Error trace:

{
	"name": "NotImplementedError",
	"message": "",
	"stack": "---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[78], line 4
      2 rng_key, rng_key_ = random.split(rng_key)
      3 predictive = Predictive(reparam_model, samples_1)
----> 4 predictions = predictive(rng_key_, x=x)[\"obs\"]
      5 df[\"Mean Predictions\"] = jnp.mean(predictions, axis=0)
      6 df.head()

File .../python3.12/site-packages/numpyro/infer/util.py:1037, in Predictive.__call__(self, rng_key, *args, **kwargs)
   1027 \"\"\"
   1028 Returns dict of samples from the predictive distribution. By default, only sample sites not
   1029 contained in `posterior_samples` are returned. This can be modified by changing the
   (...)
   1034 :param kwargs: model kwargs.
   1035 \"\"\"
   1036 if self.batch_ndims == 0 or self.params == {} or self.guide is None:
-> 1037     return self._call_with_params(rng_key, self.params, args, kwargs)
   1038 elif self.batch_ndims == 1:  # batch over parameters
   1039     batch_size = jnp.shape(jax.tree.flatten(self.params)[0][0])[0]

File .../python3.12/site-packages/numpyro/infer/util.py:1013, in Predictive._call_with_params(self, rng_key, params, args, kwargs)
   1001     posterior_samples = _predictive(
   1002         guide_rng_key,
   1003         guide,
   (...)
   1010         exclude_deterministic=self.exclude_deterministic,
   1011     )
   1012 model = substitute(self.model, self.params)
-> 1013 return _predictive(
   1014     rng_key,
   1015     model,
   1016     posterior_samples,
   1017     self._batch_shape,
   1018     return_sites=self.return_sites,
   1019     infer_discrete=self.infer_discrete,
   1020     parallel=self.parallel,
   1021     model_args=args,
   1022     model_kwargs=kwargs,
   1023     exclude_deterministic=self.exclude_deterministic,
   1024 )

File .../python3.12/site-packages/numpyro/infer/util.py:846, in _predictive(rng_key, model, posterior_samples, batch_shape, return_sites, infer_discrete, parallel, exclude_deterministic, model_args, model_kwargs)
    844 rng_key = rng_key.reshape(batch_shape + key_shape)
    845 chunk_size = num_samples if parallel else 1
--> 846 return soft_vmap(
    847     single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size
    848 )

File .../python3.12/site-packages/numpyro/util.py:452, in soft_vmap(fn, xs, batch_ndims, chunk_size)
    446     xs = jax.tree.map(
    447         lambda x: jnp.reshape(x, prepend_shape + (chunk_size,) + jnp.shape(x)[1:]),
    448         xs,
    449     )
    450     fn = vmap(fn)
--> 452 ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)
    453 map_ndims = int(num_chunks > 1) + int(chunk_size > 1)
    454 ys = jax.tree.map(
    455     lambda y: jnp.reshape(
    456         y, (int(np.prod(jnp.shape(y)[:map_ndims])),) + jnp.shape(y)[map_ndims:]
    457     )[:batch_size],
    458     ys,
    459 )

    [... skipping hidden 12 frame]

File .../python3.12/site-packages/numpyro/infer/util.py:819, in _predictive.<locals>.single_prediction(val)
    810         return (
    811             samples.get(msg[\"name\"]) if msg[\"type\"] != \"deterministic\" else None
    812         )
    814     substituted_model = (
    815         substitute(masked_model, substitute_fn=_samples_wo_deterministic)
    816         if exclude_deterministic
    817         else substitute(masked_model, samples)
    818     )
--> 819     model_trace = trace(seed(substituted_model, rng_key)).get_trace(
    820         *model_args, **model_kwargs
    821     )
    822     pred_samples = {name: site[\"value\"] for name, site in model_trace.items()}
    824 if return_sites is not None:

File .../python3.12/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
    163 def get_trace(self, *args, **kwargs):
    164     \"\"\"
    165     Run the wrapped callable and return the recorded trace.
    166 
   (...)
    169     :return: `OrderedDict` containing the execution trace.
    170     \"\"\"
--> 171     self(*args, **kwargs)
    172     return self.trace

File .../python3.12/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

File .../python3.12/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

    [... skipping similar frames: Messenger.__call__ at line 105 (3 times)]

File .../python3.12/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

Cell In[72], line 5, in model(x, y)
      3 sigma = numpyro.sample(\"sigma\", dist.Exponential(1))
      4 kappa = 1 / sigma**2
----> 5 numpyro.sample(
      6     \"obs\",
      7     dist.VonMises(loc=b * x, concentration=kappa),
      8     obs=y,
      9 )

File .../python3.12/site-packages/numpyro/primitives.py:222, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
    207 initial_msg = {
    208     \"type\": \"sample\",
    209     \"name\": name,
   (...)
    218     \"infer\": {} if infer is None else infer,
    219 }
    221 # ...and use apply_stack to send it to the Messengers
--> 222 msg = apply_stack(initial_msg)
    223 return msg[\"value\"]

File .../python3.12/site-packages/numpyro/primitives.py:47, in apply_stack(msg)
     45 pointer = 0
     46 for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 47     handler.process_message(msg)
     48     # When a Messenger sets the \"stop\" field of a message,
     49     # it prevents any Messengers above it on the stack from being applied.
     50     if msg.get(\"stop\"):

File .../python3.12/site-packages/numpyro/handlers.py:583, in reparam.process_message(self, msg)
    580 if reparam is None:
    581     return
--> 583 new_fn, value = reparam(msg[\"name\"], msg[\"fn\"], msg[\"value\"])
    585 if value is not None:
    586     if new_fn is None:

File .../python3.12/site-packages/numpyro/infer/reparam.py:344, in CircularReparam.__call__(self, name, fn, obs)
    342 # Draw parameter-free noise.
    343 new_fn = dist.ImproperUniform(constraints.real, fn.batch_shape, fn.event_shape)
--> 344 value = numpyro.sample(
    345     f\"{name}_unwrapped\",
    346     new_fn,
    347     obs=obs,
    348 )
    350 # Differentiably transform.
    351 value = jnp.remainder(value + math.pi, 2 * math.pi) - math.pi

File .../python3.12/site-packages/numpyro/primitives.py:222, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
    207 initial_msg = {
    208     \"type\": \"sample\",
    209     \"name\": name,
   (...)
    218     \"infer\": {} if infer is None else infer,
    219 }
    221 # ...and use apply_stack to send it to the Messengers
--> 222 msg = apply_stack(initial_msg)
    223 return msg[\"value\"]

File .../python3.12/site-packages/numpyro/primitives.py:53, in apply_stack(msg)
     50     if msg.get(\"stop\"):
     51         break
---> 53 default_process_message(msg)
     55 # A Messenger that sets msg[\"stop\"] == True also prevents application
     56 # of postprocess_message by Messengers above it on the stack
     57 # via the pointer variable from the process_message loop
     58 for handler in _PYRO_STACK[-pointer - 1 :]:

File .../python3.12/site-packages/numpyro/primitives.py:24, in default_process_message(msg)
     22 if msg[\"value\"] is None:
     23     if msg[\"type\"] == \"sample\":
---> 24         msg[\"value\"], msg[\"intermediates\"] = msg[\"fn\"](
     25             *msg[\"args\"], sample_intermediates=True, **msg[\"kwargs\"]
     26         )
     27     else:
     28         msg[\"value\"] = msg[\"fn\"](*msg[\"args\"], **msg[\"kwargs\"])

File .../python3.12/site-packages/numpyro/distributions/distribution.py:369, in Distribution.__call__(self, *args, **kwargs)
    367 sample_intermediates = kwargs.pop(\"sample_intermediates\", False)
    368 if sample_intermediates:
--> 369     return self.sample_with_intermediates(key, *args, **kwargs)
    370 return self.sample(key, *args, **kwargs)

File .../python3.12/site-packages/numpyro/distributions/distribution.py:327, in Distribution.sample_with_intermediates(self, key, sample_shape)
    317 def sample_with_intermediates(self, key, sample_shape=()):
    318     \"\"\"
    319     Same as ``sample`` except that any intermediate computations are
    320     returned (useful for `TransformedDistribution`).
   (...)
    325     :rtype: numpy.ndarray
    326     \"\"\"
--> 327     return self.sample(key, sample_shape=sample_shape), []

File .../python3.12/site-packages/numpyro/distributions/distribution.py:909, in MaskedDistribution.sample(self, key, sample_shape)
    908 def sample(self, key, sample_shape=()):
--> 909     return self.base_dist(rng_key=key, sample_shape=sample_shape)

File .../python3.12/site-packages/numpyro/distributions/distribution.py:370, in Distribution.__call__(self, *args, **kwargs)
    368 if sample_intermediates:
    369     return self.sample_with_intermediates(key, *args, **kwargs)
--> 370 return self.sample(key, *args, **kwargs)

File .../python3.12/site-packages/numpyro/distributions/distribution.py:315, in Distribution.sample(self, key, sample_shape)
    303 def sample(self, key, sample_shape=()):
    304     \"\"\"
    305     Returns a sample from the distribution having shape given by
    306     `sample_shape + batch_shape + event_shape`. Note that when `sample_shape` is non-empty,
   (...)
    313     :rtype: numpy.ndarray
    314     \"\"\"
--> 315     raise NotImplementedError

NotImplementedError: "
}

Note that there is no error if you pass `model` instead of `reparam_model` in the code above.

Thanks in advance for any input!

@fehiepsi
Copy link
Member

fehiepsi commented Aug 7, 2024

Hmm, I think we don't need to reparameterize the likelihood. We should add an assertion, as the other reparameterizers do, to disallow that usage.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging a pull request may close this issue.

2 participants