ValueError Traceback (most recent call last)
<ipython-input-1-fc7fe60e083b> in <module>
10
11 with numpyro.handlers.seed(rng_seed=54):
---> 12 print(model())
<ipython-input-1-fc7fe60e083b> in model()
6 def model():
7 probs = jnp.array([0.5, 0.5, 0.])
----> 8 c = numpyro.sample('c', dist.Categorical(probs=probs))
9 return c
10
~/miniconda3/envs/numpyro_test/lib/python3.8/site-packages/numpyro/distributions/discrete.py in Categorical(probs, logits, validate_args)
348 def Categorical(probs=None, logits=None, validate_args=None):
349 if probs is not None:
--> 350 return CategoricalProbs(probs, validate_args=validate_args)
351 elif logits is not None:
352 return CategoricalLogits(logits, validate_args=validate_args)
~/miniconda3/envs/numpyro_test/lib/python3.8/site-packages/numpyro/distributions/discrete.py in __init__(self, probs, validate_args)
265 raise ValueError("`probs` parameter must be at least one-dimensional.")
266 self.probs = probs
--> 267 super(CategoricalProbs, self).__init__(batch_shape=jnp.shape(self.probs)[:-1],
268 validate_args=validate_args)
269
~/miniconda3/envs/numpyro_test/lib/python3.8/site-packages/numpyro/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
142 if not_jax_tracer(is_valid):
143 if not is_valid:
--> 144 raise ValueError("The parameter {} has invalid values".format(param))
145 super(Distribution, self).__init__()
146
ValueError: The parameter probs has invalid values
I get the ValueError shown above when I run a small model that just samples from a categorical distribution with the probabilities [0.5, 0.5, 0.].
I think the problem is caused by the validation, because if I restart my kernel and comment out the line
numpyro.enable_validation(True)
the code runs without a problem and prints 0 in my case. If I write similar code in Pyro with validation enabled, I do not get an error.
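One possible explanation is that the validation of `probs` requires every entry to be strictly positive, while Pyro's check accepts zeros. This is an assumption and has not been checked against the NumPyro or Pyro source. The pure-Python sketch below only illustrates how the two conventions would differ on this input; `is_simplex_strict` and `is_simplex_closed` are hypothetical names, not library functions.

```python
def is_simplex_strict(probs, tol=1e-6):
    # Hypothetical check: every entry strictly positive, entries sum to 1.
    return all(p > 0 for p in probs) and abs(sum(probs) - 1.0) < tol


def is_simplex_closed(probs, tol=1e-6):
    # Hypothetical check: zeros allowed, entries sum to 1.
    return all(p >= 0 for p in probs) and abs(sum(probs) - 1.0) < tol


probs = [0.5, 0.5, 0.0]
print(is_simplex_strict(probs))  # False: the zero entry fails p > 0
print(is_simplex_closed(probs))  # True: zero is allowed and the sum is 1
```

If the validation does behave like the strict variant, a probability vector with an exact zero would be rejected even though it is a valid categorical distribution.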
I am using Python 3.8.5, Pyro 1.4.0 and NumPyro 0.3.0 on Ubuntu. Happy to help with what I can.