Possible error in the validation of a Categorical distribution #736

Closed
RaulPL opened this issue Sep 11, 2020 · 3 comments · Fixed by #737

Comments

@RaulPL
Contributor

RaulPL commented Sep 11, 2020

I am getting an error when I try to run the following code. The code just samples from a categorical distribution using the defined probabilities.

```python
import numpyro
import numpyro.distributions as dist
import jax.numpy as jnp
numpyro.enable_validation(True)


def model():
    probs = jnp.array([0.5, 0.5, 0.])
    c = numpyro.sample('c', dist.Categorical(probs=probs))
    return c

with numpyro.handlers.seed(rng_seed=54):
    print(model())
```
```
ValueError                                Traceback (most recent call last)
<ipython-input-1-fc7fe60e083b> in <module>
     10 
     11 with numpyro.handlers.seed(rng_seed=54):
---> 12     print(model())

<ipython-input-1-fc7fe60e083b> in model()
      6 def model():
      7     probs = jnp.array([0.5, 0.5, 0.])
----> 8     c = numpyro.sample('c', dist.Categorical(probs=probs))
      9     return c
     10 

~/miniconda3/envs/numpyro_test/lib/python3.8/site-packages/numpyro/distributions/discrete.py in Categorical(probs, logits, validate_args)
    348 def Categorical(probs=None, logits=None, validate_args=None):
    349     if probs is not None:
--> 350         return CategoricalProbs(probs, validate_args=validate_args)
    351     elif logits is not None:
    352         return CategoricalLogits(logits, validate_args=validate_args)

~/miniconda3/envs/numpyro_test/lib/python3.8/site-packages/numpyro/distributions/discrete.py in __init__(self, probs, validate_args)
    265             raise ValueError("`probs` parameter must be at least one-dimensional.")
    266         self.probs = probs
--> 267         super(CategoricalProbs, self).__init__(batch_shape=jnp.shape(self.probs)[:-1],
    268                                                validate_args=validate_args)
    269 

~/miniconda3/envs/numpyro_test/lib/python3.8/site-packages/numpyro/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
    142                 if not_jax_tracer(is_valid):
    143                     if not is_valid:
--> 144                         raise ValueError("The parameter {} has invalid values".format(param))
    145         super(Distribution, self).__init__()
    146 

ValueError: The parameter probs has invalid values
```

I think the problem is caused by the validation: if I restart my kernel and comment out the line `numpyro.enable_validation(True)`, the code runs without problems and prints 0 in my case.
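The traceback shows the error being raised by the parameter check in `Distribution.__init__`, which only runs when validation is enabled. A minimal sketch of that pattern, assuming a hypothetical `ToyCategorical` class written in plain NumPy (not NumPyro's actual implementation), shows why the same `probs` are accepted with validation off and rejected with it on:

```python
import numpy as np


class ToyCategorical:
    """Hypothetical stand-in for a validated categorical; not NumPyro's class."""

    def __init__(self, probs, validate_args=False):
        self.probs = np.asarray(probs, dtype=float)
        # The parameter check only runs when validation is switched on.
        if validate_args and not self.is_simplex(self.probs):
            # Same message as the last line of the traceback in this issue.
            raise ValueError("The parameter probs has invalid values")

    @staticmethod
    def is_simplex(x, tol=1e-6):
        # Strict form of the check discussed in this thread: every entry must be > 0.
        x = np.asarray(x, dtype=float)
        x_sum = np.sum(x, axis=-1)
        ok = np.all(x > 0, axis=-1) & (x_sum < 1 + tol) & (x_sum > 1 - tol)
        return bool(np.all(ok))  # collapse any batch dimension to one bool


probs = [0.5, 0.5, 0.0]
ToyCategorical(probs)  # validation off: constructed without an error
try:
    ToyCategorical(probs, validate_args=True)
except ValueError as err:
    print(err)  # The parameter probs has invalid values
```

With validation off the zero entry is never inspected, which matches the behavior described when `enable_validation` is commented out.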

If I write similar code in Pyro with validation enabled, I do not get an error.

```python
import torch
import pyro
import pyro.distributions as dist
pyro.enable_validation(True)
pyro.set_rng_seed(54)

def model():
    probs = torch.tensor([0.5, 0.5, 0.])
    c = pyro.sample('c', dist.Categorical(probs=probs))
    return c

print(model())
```
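With `probs = [0.5, 0.5, 0.]` the third category has probability zero, which is still a valid categorical distribution: that outcome is simply never drawn. The sketch below illustrates this with inverse-CDF sampling in NumPy. The helper `sample_categorical` is hypothetical and chosen for illustration only; it is not how NumPyro or Pyro draw samples, and seed 54 here does not reproduce either library's output.

```python
import numpy as np


def sample_categorical(probs, num_samples, seed):
    """Hypothetical inverse-CDF sampler for a categorical distribution."""
    rng = np.random.default_rng(seed)
    cdf = np.cumsum(probs)
    cdf = cdf / cdf[-1]            # guard against rounding: last entry is exactly 1.0
    u = rng.random(num_samples)    # uniform draws in [0, 1)
    # Index of the first CDF entry strictly greater than u. A zero-probability
    # category repeats the previous CDF value, so no u can ever select it.
    return np.searchsorted(cdf, u, side="right")


probs = np.array([0.5, 0.5, 0.0])
samples = sample_categorical(probs, num_samples=10_000, seed=54)
counts = np.bincount(samples, minlength=len(probs))
print(counts)  # the last count is always 0; the first two are each close to 5000
```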

I am using Python 3.8.5, Pyro 1.4.0, and NumPyro 0.3.0 on Ubuntu. Happy to help however I can.

@fehiepsi
Member

Thanks, @RaulPL! Could you help us relax the check for interval constraint at this line?

@RaulPL
Contributor Author

RaulPL commented Sep 11, 2020

I think the line to change is this one. Change it from this:

```python
return jnp.all(x > 0, axis=-1) & (x_sum < 1 + 1e-6) & (x_sum > 1 - 1e-6)
```

to this:

```python
return jnp.all(x >= 0, axis=-1) & (x_sum < 1 + 1e-6) & (x_sum > 1 - 1e-6)
```

I think this change will fix it. I can make the pull request if that's okay with you.
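To see why relaxing `>` to `>=` is enough, here is a minimal sketch of the two checks as standalone functions. The function names are hypothetical, `x_sum` is assumed to be the sum over the last axis, and NumPy stands in for `jax.numpy` so the snippet runs without JAX (`all`, `sum`, and the comparisons used here behave the same way in both).

```python
import numpy as np  # stand-in for jax.numpy; the calls used here exist in both

TOL = 1e-6  # tolerance on the sum, as in the quoted line


def strict_simplex_check(x):
    """Hypothetical copy of the original check: every entry must be > 0."""
    x = np.asarray(x, dtype=float)
    x_sum = np.sum(x, axis=-1)  # assumed definition of x_sum
    return np.all(x > 0, axis=-1) & (x_sum < 1 + TOL) & (x_sum > 1 - TOL)


def relaxed_simplex_check(x):
    """Hypothetical copy of the proposed check: entries may be exactly 0."""
    x = np.asarray(x, dtype=float)
    x_sum = np.sum(x, axis=-1)
    return np.all(x >= 0, axis=-1) & (x_sum < 1 + TOL) & (x_sum > 1 - TOL)


batch = [[0.5, 0.5, 0.0],    # the probs from this issue
         [0.2, 0.3, 0.5],    # strictly positive
         [0.6, 0.6, -0.2]]   # sums to 1 but has a negative entry
print(strict_simplex_check(batch))   # [False  True False]
print(relaxed_simplex_check(batch))  # [ True  True False]
```

The relaxed version still rejects negative entries and sums away from 1; it only stops treating an exact zero as invalid, which matches the non-negativity requirement of a probability vector.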

@fehiepsi
Member

You are right, sorry for the confusion. Thanks so much for diagnosing the issue and finding the fix!

2 participants