You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
I'm trying to use the Predictive class with a model whose response variable is circular. I get a deep NotImplementedError when I try to use Predictive with the (recommended) circular reparameterization, but not when I don't use this.
Is this error related to not being able to use reparameterization on the observed variable (likelihood)? Or something specific to circular reparameterization? Any help much appreciated!
Minimal example:
importnumpyroimportjax.numpyasjnpfromjaximportrandomimportnumpyro.distributionsasdistfromnumpyro.inferimportMCMC, NUTS, Predictivefromnumpyro.handlersimportreparamfromnumpyro.infer.reparamimportCircularReparamx=jnp.linspace(-jnp.pi, jnp.pi)
y=x+random.normal(random.key(234), shape=x.shape) # acknowledged that this will be out-of-boundsdefmodel(x, y=None):
b=numpyro.sample("b", dist.Normal(0, 1))
sigma=numpyro.sample("sigma", dist.Exponential(1))
kappa=1/sigma**2numpyro.sample(
"obs",
dist.VonMises(loc=b*x, concentration=kappa),
obs=y,
)
reparam_model=reparam(model, config={"obs": CircularReparam()})
# Start from this source of randomness. We will split keys for subsequent operations.rng_key=random.key(4159)
rng_key, rng_key_=random.split(rng_key)
# Run NUTS.kernel=NUTS(reparam_model)
mcmc=MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4)
mcmc.run(rng_key_, x=x, y=y)
mcmc.print_summary()
samples_1=mcmc.get_samples()
# generate posterior predictions on fitted valuesrng_key, rng_key_=random.split(rng_key)
predictive=Predictive(reparam_model, samples_1)
predictions=predictive(rng_key_, x=x)["obs"]
Error trace:
{
"name": "NotImplementedError",
"message": "",
"stack": "---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Cell In[78], line 4
2 rng_key, rng_key_ = random.split(rng_key)
3 predictive = Predictive(reparam_model, samples_1)
----> 4 predictions = predictive(rng_key_, x=x)[\"obs\"]
5 df[\"Mean Predictions\"] = jnp.mean(predictions, axis=0)
6 df.head()
File .../python3.12/site-packages/numpyro/infer/util.py:1037, in Predictive.__call__(self, rng_key, *args, **kwargs)
1027 \"\"\"
1028 Returns dict of samples from the predictive distribution. By default, only sample sites not
1029 contained in `posterior_samples` are returned. This can be modified by changing the
(...)
1034 :param kwargs: model kwargs.
1035 \"\"\"
1036 if self.batch_ndims == 0 or self.params == {} or self.guide is None:
-> 1037 return self._call_with_params(rng_key, self.params, args, kwargs)
1038 elif self.batch_ndims == 1: # batch over parameters
1039 batch_size = jnp.shape(jax.tree.flatten(self.params)[0][0])[0]
File .../python3.12/site-packages/numpyro/infer/util.py:1013, in Predictive._call_with_params(self, rng_key, params, args, kwargs)
1001 posterior_samples = _predictive(
1002 guide_rng_key,
1003 guide,
(...)
1010 exclude_deterministic=self.exclude_deterministic,
1011 )
1012 model = substitute(self.model, self.params)
-> 1013 return _predictive(
1014 rng_key,
1015 model,
1016 posterior_samples,
1017 self._batch_shape,
1018 return_sites=self.return_sites,
1019 infer_discrete=self.infer_discrete,
1020 parallel=self.parallel,
1021 model_args=args,
1022 model_kwargs=kwargs,
1023 exclude_deterministic=self.exclude_deterministic,
1024 )
File .../python3.12/site-packages/numpyro/infer/util.py:846, in _predictive(rng_key, model, posterior_samples, batch_shape, return_sites, infer_discrete, parallel, exclude_deterministic, model_args, model_kwargs)
844 rng_key = rng_key.reshape(batch_shape + key_shape)
845 chunk_size = num_samples if parallel else 1
--> 846 return soft_vmap(
847 single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size
848 )
File .../python3.12/site-packages/numpyro/util.py:452, in soft_vmap(fn, xs, batch_ndims, chunk_size)
446 xs = jax.tree.map(
447 lambda x: jnp.reshape(x, prepend_shape + (chunk_size,) + jnp.shape(x)[1:]),
448 xs,
449 )
450 fn = vmap(fn)
--> 452 ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)
453 map_ndims = int(num_chunks > 1) + int(chunk_size > 1)
454 ys = jax.tree.map(
455 lambda y: jnp.reshape(
456 y, (int(np.prod(jnp.shape(y)[:map_ndims])),) + jnp.shape(y)[map_ndims:]
457 )[:batch_size],
458 ys,
459 )
[... skipping hidden 12 frame]
File .../python3.12/site-packages/numpyro/infer/util.py:819, in _predictive.<locals>.single_prediction(val)
810 return (
811 samples.get(msg[\"name\"]) if msg[\"type\"] != \"deterministic\" else None
812 )
814 substituted_model = (
815 substitute(masked_model, substitute_fn=_samples_wo_deterministic)
816 if exclude_deterministic
817 else substitute(masked_model, samples)
818 )
--> 819 model_trace = trace(seed(substituted_model, rng_key)).get_trace(
820 *model_args, **model_kwargs
821 )
822 pred_samples = {name: site[\"value\"] for name, site in model_trace.items()}
824 if return_sites is not None:
File .../python3.12/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
163 def get_trace(self, *args, **kwargs):
164 \"\"\"
165 Run the wrapped callable and return the recorded trace.
166
(...)
169 :return: `OrderedDict` containing the execution trace.
170 \"\"\"
--> 171 self(*args, **kwargs)
172 return self.trace
File .../python3.12/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
File .../python3.12/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
[... skipping similar frames: Messenger.__call__ at line 105 (3 times)]
File .../python3.12/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
Cell In[72], line 5, in model(x, y)
3 sigma = numpyro.sample(\"sigma\", dist.Exponential(1))
4 kappa = 1 / sigma**2
----> 5 numpyro.sample(
6 \"obs\",
7 dist.VonMises(loc=b * x, concentration=kappa),
8 obs=y,
9 )
File .../python3.12/site-packages/numpyro/primitives.py:222, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
207 initial_msg = {
208 \"type\": \"sample\",
209 \"name\": name,
(...)
218 \"infer\": {} if infer is None else infer,
219 }
221 # ...and use apply_stack to send it to the Messengers
--> 222 msg = apply_stack(initial_msg)
223 return msg[\"value\"]
File .../python3.12/site-packages/numpyro/primitives.py:47, in apply_stack(msg)
45 pointer = 0
46 for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 47 handler.process_message(msg)
48 # When a Messenger sets the \"stop\" field of a message,
49 # it prevents any Messengers above it on the stack from being applied.
50 if msg.get(\"stop\"):
File .../python3.12/site-packages/numpyro/handlers.py:583, in reparam.process_message(self, msg)
580 if reparam is None:
581 return
--> 583 new_fn, value = reparam(msg[\"name\"], msg[\"fn\"], msg[\"value\"])
585 if value is not None:
586 if new_fn is None:
File .../python3.12/site-packages/numpyro/infer/reparam.py:344, in CircularReparam.__call__(self, name, fn, obs)
342 # Draw parameter-free noise.
343 new_fn = dist.ImproperUniform(constraints.real, fn.batch_shape, fn.event_shape)
--> 344 value = numpyro.sample(
345 f\"{name}_unwrapped\",
346 new_fn,
347 obs=obs,
348 )
350 # Differentiably transform.
351 value = jnp.remainder(value + math.pi, 2 * math.pi) - math.pi
File .../python3.12/site-packages/numpyro/primitives.py:222, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
207 initial_msg = {
208 \"type\": \"sample\",
209 \"name\": name,
(...)
218 \"infer\": {} if infer is None else infer,
219 }
221 # ...and use apply_stack to send it to the Messengers
--> 222 msg = apply_stack(initial_msg)
223 return msg[\"value\"]
File .../python3.12/site-packages/numpyro/primitives.py:53, in apply_stack(msg)
50 if msg.get(\"stop\"):
51 break
---> 53 default_process_message(msg)
55 # A Messenger that sets msg[\"stop\"] == True also prevents application
56 # of postprocess_message by Messengers above it on the stack
57 # via the pointer variable from the process_message loop
58 for handler in _PYRO_STACK[-pointer - 1 :]:
File .../python3.12/site-packages/numpyro/primitives.py:24, in default_process_message(msg)
22 if msg[\"value\"] is None:
23 if msg[\"type\"] == \"sample\":
---> 24 msg[\"value\"], msg[\"intermediates\"] = msg[\"fn\"](
25 *msg[\"args\"], sample_intermediates=True, **msg[\"kwargs\"]
26 )
27 else:
28 msg[\"value\"] = msg[\"fn\"](*msg[\"args\"], **msg[\"kwargs\"])
File .../python3.12/site-packages/numpyro/distributions/distribution.py:369, in Distribution.__call__(self, *args, **kwargs)
367 sample_intermediates = kwargs.pop(\"sample_intermediates\", False)
368 if sample_intermediates:
--> 369 return self.sample_with_intermediates(key, *args, **kwargs)
370 return self.sample(key, *args, **kwargs)
File .../python3.12/site-packages/numpyro/distributions/distribution.py:327, in Distribution.sample_with_intermediates(self, key, sample_shape)
317 def sample_with_intermediates(self, key, sample_shape=()):
318 \"\"\"
319 Same as ``sample`` except that any intermediate computations are
320 returned (useful for `TransformedDistribution`).
(...)
325 :rtype: numpy.ndarray
326 \"\"\"
--> 327 return self.sample(key, sample_shape=sample_shape), []
File .../python3.12/site-packages/numpyro/distributions/distribution.py:909, in MaskedDistribution.sample(self, key, sample_shape)
908 def sample(self, key, sample_shape=()):
--> 909 return self.base_dist(rng_key=key, sample_shape=sample_shape)
File .../python3.12/site-packages/numpyro/distributions/distribution.py:370, in Distribution.__call__(self, *args, **kwargs)
368 if sample_intermediates:
369 return self.sample_with_intermediates(key, *args, **kwargs)
--> 370 return self.sample(key, *args, **kwargs)
File .../python3.12/site-packages/numpyro/distributions/distribution.py:315, in Distribution.sample(self, key, sample_shape)
303 def sample(self, key, sample_shape=()):
304 \"\"\"
305 Returns a sample from the distribution having shape given by
306 `sample_shape + batch_shape + event_shape`. Note that when `sample_shape` is non-empty,
(...)
313 :rtype: numpy.ndarray
314 \"\"\"
--> 315 raise NotImplementedError
NotImplementedError: "
}
Note there's no error if you replace reparam_model with model above.
Thanks in advance for input!
The text was updated successfully, but these errors were encountered:
Hello,
I'm trying to use the
Predictive
class with a model whose response variable is circular. I get a deepNotImplementedError
when I try to usePredictive
with the (recommended) circular reparameterization, but not when I don't use this.Is this error related to not being able to use reparameterization on the observed variable (likelihood)? Or something specific to circular reparameterization? Any help much appreciated!
Minimal example:
Error trace:
Note there's no error if you replace
reparam_model
withmodel
above.Thanks in advance for input!
The text was updated successfully, but these errors were encountered: