Conditional distribution with MvNormal does not always work #99

Open
MercuryBench opened this issue Oct 27, 2024 · 0 comments

Reproducible example: Let w be a categorical variable taking the values 0 or 1 with probability 0.5 each, and let z|w be normally distributed with mean w and scale 0.01. This works in 1d:

from collections import OrderedDict

import numpy as np
from particles import distributions as dists  # assumed source of `dists`

chain_rule = OrderedDict()
chain_rule['w'] = dists.Categorical(p=np.array([0.5, 0.5]))
# z | w ~ N(w, 0.01^2); x['w'] holds one value of w per requested sample
chain_rule['z'] = dists.Cond(lambda x: dists.Normal(loc=x['w'], scale=0.01), dim=1)
mydist = dists.StructDist(chain_rule)
mydist.rvs(5)

This works mainly because numpy.random.normal accepts a vectorised mean, e.g. np.random.normal(np.array([0,10,20]), size=3), so one mean per sample can be supplied in a single call. np.random.multivariate_normal does not accept a stack of means in this way, which is why the following example does not work:

chain_rule = OrderedDict()
chain_rule['w'] = dists.Categorical(p = np.array([0.5, 0.5]))
chain_rule['z'] = dists.Cond(lambda x: dists.MvNormal(loc=x['w']*np.array([1,0]), scale=0.01), dim=2)
mydist = dists.StructDist(chain_rule)
mydist.rvs(5)

This fails with the following error:

chain_rule['z'] = dists.Cond(lambda x: dists.MvNormal(loc=x['w']*np.array([1,0]), scale=0.01), dim=2)
ValueError: operands could not be broadcast together with shapes (5,) (2,) 
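
The behaviour described above can be checked with NumPy alone, without the `dists` classes. This is only a rough sketch: the array `w` is a hypothetical stand-in for five draws of the categorical variable. It also suggests that the shapes in the error message, (5,) and (2,), correspond to the product `x['w']*np.array([1,0])` itself, which is evaluated before any sampler is called.

```python
import numpy as np

w = np.array([0, 1, 1, 0, 1])  # hypothetical stand-in for five draws of w

# Univariate case: an array of means yields one draw per mean.
z_1d = np.random.normal(loc=w, scale=0.01)
print(z_1d.shape)  # (5,)

# A (5,) array of w values cannot be multiplied by a (2,) vector.
try:
    w * np.array([1, 0])
    broadcast_error = None
except ValueError as err:
    broadcast_error = str(err)
print(broadcast_error)

# Even with an explicit (5, 1) * (2,) -> (5, 2) product,
# multivariate_normal rejects a stack of means: it wants a 1-D mean.
means = w[:, None] * np.array([1, 0])
try:
    np.random.multivariate_normal(means, 0.01**2 * np.eye(2))
    mvn_error = None
except ValueError as err:
    mvn_error = str(err)
print(mvn_error)
```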

Fundamentally, I believe that the only recourse here is to sample the rvs sequentially, i.e., to loop over the requested samples (sample w first, then z|w, and repeat as many times as necessary). This will be much slower, of course.
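
The sequential approach could look roughly like the following, written in plain NumPy rather than with the `dists` classes. The function name `sample_sequentially` and its parameters are hypothetical, and the covariance is assumed to be isotropic with standard deviation `scale`, mirroring `scale=0.01` in the example.

```python
import numpy as np

def sample_sequentially(n, scale=0.01, seed=None):
    """Draw n pairs (w, z): w first, then z | w, one sample at a time."""
    rng = np.random.default_rng(seed)
    direction = np.array([1.0, 0.0])
    cov = scale**2 * np.eye(2)  # assumed isotropic covariance
    w = rng.choice(2, size=n, p=[0.5, 0.5])  # w in {0, 1}, probability 0.5 each
    z = np.empty((n, 2))
    for i in range(n):
        # multivariate_normal accepts only a 1-D mean, hence one call per sample
        z[i] = rng.multivariate_normal(w[i] * direction, cov)
    return w, z

w, z = sample_sequentially(5, seed=0)
print(w.shape, z.shape)  # (5,) (5, 2)
```

The Python-level loop makes one sampler call per draw, which is where the slowdown mentioned above comes from.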
