LKJCorr and LKJCholeskyCov refactor #5382
Conversation
Codecov Report
@@            Coverage Diff             @@
##             main    #5382      +/-   ##
==========================================
+ Coverage   80.43%   81.39%   +0.95%
==========================================
  Files          82       82
  Lines       14159    14213      +54
==========================================
+ Hits        11389    11568     +179
+ Misses       2770     2645     -125
beta -= 0.5
y = stats.beta.rvs(a=mp1 / 2.0, b=beta, size=size, random_state=rng)
z = stats.norm.rvs(loc=0, scale=1, size=(size, mp1), random_state=rng)
z = z / np.sqrt(np.einsum("ij,ij->i", z, z))[..., np.newaxis]
@tomicapretto, continuing from #4784 (comment): here is where I think the not-yet-refactored LKJCholeskyCov random method was wrong. My understanding is that it was trying to generalize the pre-existing code in LKJCorr (removed in this PR) to allow for more flexible sizes, but in doing so altered the meaning of this einsum. The still "buggy" code there, for reference, is this:
pymc/pymc/distributions/multivariate.py, line 1183 in 0b5f8d1:

z = z / np.sqrt(np.einsum("ij,ij->j", z, z))
CC @lucianopaz, I think you originally wrote this code (for the Cholesky) in #3293; do you have a chance to chime in?
I remember that PR being centered around making a mixture distribution of MvNormals work, and to be able to sample from their prior. The flexible size stuff came from over there. I hope that I did not mess up the einsum back then, but I honestly don't remember why I had written "j" instead of "i", and I don't remember the algorithm of the rng at all.
When I used this new, more flexible logic for LKJCorrRV, the random test failed, and that seems to be the strictest test for the rng that we have. I then started debugging line by line, and the einsum index was what changed the results between the old logic in LKJCorrRV and the more flexible one in LKJCholeskyRV.
LKJCholeskyCov is also refactored!
Tests are passing!
…butionRandom`
* Fixes bug in returned samples from `Wishart` when `size=1`
This looks great. I don't know what to say about the einsum though.
Changes:
* compute_corr now defaults to True
* LKJCholeskyCov now also provides a `.dist` interface
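A minimal sketch of how the new `.dist` interface might be used outside a model context (the return value with compute_corr=False, the explicit sd_dist shape, and drawing via .eval() are assumptions, not taken from this PR's diff):

import pymc as pm

# Assumed: with compute_corr=False, .dist returns the packed Cholesky factor
# as a single variable; .eval() draws one sample from it.
packed_chol = pm.LKJCholeskyCov.dist(
    n=3, eta=2.0, sd_dist=pm.Exponential.dist(1.0, shape=3), compute_corr=False
)
print(packed_chol.eval().shape)  # expected (6,): n * (n + 1) // 2 packed entries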
🥳
This is really amazing work! I wanted to give it a go and thought I'd try the LKJ example notebook. However, that didn't seem to work, unfortunately! I messed around with things a little bit but couldn't get the sampling to work. Sorry, it's entirely possible that my Python environment isn't exactly right -- I did upgrade to the aesara and aeppl versions listed here, but maybe I missed something. In any case, it might be worth trying the example notebook, @ricardoV94!
@martiningram I haven't tried to run the notebook, but I see the code has a slight issue relative to V4. The sd_dist should have the same shape as … I'll try and run the notebook some time soon.
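For concreteness, a hedged sketch of the adjustment hinted at here, giving sd_dist an explicit shape that matches n (this is an assumption based on the reshape error reported below, not a confirmed fix from this PR):

import pymc as pm

coords = {"axis": ["y", "z"], "axis_bis": ["y", "z"]}
with pm.Model(coords=coords) as model:
    chol, corr, stds = pm.LKJCholeskyCov(
        "chol",
        n=2,
        eta=2.0,
        sd_dist=pm.Exponential.dist(1.0, shape=2),  # assumed fix: shape matches n=2
        compute_corr=True,
    )
    cov = pm.Deterministic("cov", chol.dot(chol.T), dims=("axis", "axis_bis"))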
Forgot to ask, what error are you seeing?
The first is in cell 7, following:

packed_L.tag.test_value.shape

That gives:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/tmp/ipykernel_30498/1180155108.py in <module>
----> 1 packed_L.tag.test_value.shape
AttributeError: 'tuple' object has no attribute 'tag'

If I skip that line and the next, running

coords = {"axis": ["y", "z"], "axis_bis": ["y", "z"], "obs_id": np.arange(N)}
with pm.Model(coords=coords) as model:
    chol, corr, stds = pm.LKJCholeskyCov(
        "chol", n=2, eta=2.0, sd_dist=pm.Exponential.dist(1.0), compute_corr=True
    )
    cov = pm.Deterministic("cov", chol.dot(chol.T), dims=("axis", "axis_bis"))

seems to work, and the next cell also, but sampling gives:

Auto-assigning NUTS sampler...
Initializing NUTS using adapt_diag...
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
~/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/link/vm.py in __call__(self)
308 ):
--> 309 thunk()
310 for old_s in old_storage:
~/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/graph/op.py in rval(p, i, o, n)
507 def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
--> 508 r = p(n, [x[0] for x in i], o)
509 for o in node.outputs:
~/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/tensor/random/op.py in perform(self, node, inputs, outputs)
381
--> 382 smpl_val = self.rng_fn(rng, *(args + [size]))
383
~/projects/pymc3_vs_stan/pymc/pymc/distributions/multivariate.py in rng_fn(self, rng, n, eta, D, size)
1149
-> 1150 D = D.reshape(flat_size, n)
1151 C *= D[..., :, np.newaxis] * D[..., np.newaxis, :]
ValueError: cannot reshape array of size 1 into shape (1,2)
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
/tmp/ipykernel_30498/1518530166.py in <module>
1 with model:
----> 2 trace = pm.sample(
3 random_seed=RANDOM_SEED,
4 init="adapt_diag",
5 return_inferencedata=True,
~/projects/pymc3_vs_stan/pymc/pymc/sampling.py in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
494 # By default, try to use NUTS
495 _log.info("Auto-assigning NUTS sampler...")
--> 496 initial_points, step = init_nuts(
497 init=init,
498 chains=chains,
~/projects/pymc3_vs_stan/pymc/pymc/sampling.py in init_nuts(init, chains, n_init, model, seeds, progressbar, jitter_max_retries, tune, initvals, **kwargs)
2318 ]
2319
-> 2320 initial_points = _init_jitter(
2321 model,
2322 initvals,
~/projects/pymc3_vs_stan/pymc/pymc/sampling.py in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries)
2195
2196 if not jitter:
-> 2197 return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)]
2198
2199 initial_points = []
~/projects/pymc3_vs_stan/pymc/pymc/sampling.py in <listcomp>(.0)
2195
2196 if not jitter:
-> 2197 return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)]
2198
2199 initial_points = []
~/projects/pymc3_vs_stan/pymc/pymc/initial_point.py in inner(seed, *args, **kwargs)
214 new_rng = np.random.Generator(seed)
215 rng.set_value(new_rng, True)
--> 216 values = func(*args, **kwargs)
217 return dict(zip(varnames, values))
218
~/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/compile/function/types.py in __call__(self, *args, **kwargs)
967 try:
968 outputs = (
--> 969 self.fn()
970 if output_subset is None
971 else self.fn(output_subset=output_subset)
~/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/link/vm.py in __call__(self)
311 old_s[0] = None
312 except Exception:
--> 313 raise_with_op(self.fgraph, node, thunk)
314
315
~/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/link/utils.py in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
523 # Some exception need extra parameter in inputs. So forget the
524 # extra long error message in that case.
--> 525 raise exc_value.with_traceback(exc_trace)
526
527
~/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/link/vm.py in __call__(self)
307 self.thunks, self.nodes, self.post_thunk_clear
308 ):
--> 309 thunk()
310 for old_s in old_storage:
311 old_s[0] = None
~/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/graph/op.py in rval(p, i, o, n)
506 # default arguments are stored in the closure of `rval`
507 def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
--> 508 r = p(n, [x[0] for x in i], o)
509 for o in node.outputs:
510 compute_map[o][0] = True
~/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/tensor/random/op.py in perform(self, node, inputs, outputs)
380 rng_var_out[0] = rng
381
--> 382 smpl_val = self.rng_fn(rng, *(args + [size]))
383
384 if (
~/projects/pymc3_vs_stan/pymc/pymc/distributions/multivariate.py in rng_fn(self, rng, n, eta, D, size)
1148 C = LKJCorrRV._random_corr_matrix(rng, n, eta, flat_size)
1149
-> 1150 D = D.reshape(flat_size, n)
1151 C *= D[..., :, np.newaxis] * D[..., np.newaxis, :]
1152
ValueError: cannot reshape array of size 1 into shape (1,2)
Apply node that caused the error: _lkjcholeskycov_rv{1, (0, 0, 1), floatX, False}(RandomStateSharedVariable(<RandomState(PCG64) at 0x7F3D4D447E40>), TensorConstant{[]}, TensorConstant{11}, TensorConstant{2}, TensorConstant{2.0}, exponential_rv{0, (0,), floatX, False}.out)
Toposort index: 1
Inputs types: [RandomStateType, TensorType(int64, (0,)), TensorType(int64, ()), TensorType(int32, ()), TensorType(float64, ()), TensorType(float64, ())]
Inputs shapes: ['No shapes', (0,), (), (), (), ()]
Inputs strides: ['No strides', (8,), (), (), (), ()]
Inputs values: [RandomState(PCG64) at 0x7F3D4D447E40, array([], dtype=int64), array(11), array(2, dtype=int32), array(2.), array(2.01457564)]
Outputs clients: [[], [Elemwise{second,no_inplace}(chol, TensorConstant{(1,) of 0.0})]]
Backtrace when the node is created (use Aesara flag traceback__limit=N to make it longer):
File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/distributions/multivariate.py", line 1181, in __new__
return super().__new__(cls, name, eta, n, sd_dist, **kwargs)
File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/distributions/distribution.py", line 266, in __new__
rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape(
File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/distributions/distribution.py", line 165, in _make_rv_and_resize_shape
rv_out = cls.dist(*args, **kwargs)
File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/distributions/multivariate.py", line 1203, in dist
return super().dist([n, eta, sd_dist], size=size, **kwargs)
File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/distributions/distribution.py", line 353, in dist
rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
File "/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/tensor/random/op.py", line 293, in __call__
res = super().__call__(rng, size, dtype, *args, **kwargs)
File "/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/graph/op.py", line 283, in __call__
node = self.make_node(*inputs, **kwargs)
File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/distributions/multivariate.py", line 1134, in make_node
return super().make_node(rng, size, dtype, n, eta, D)
HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

In case it's helpful, I also modified the notebook somewhat to remove the failing lines and to try to implement your suggestion of giving the sd_dist the right shape. That gives:

Auto-assigning NUTS sampler...
Initializing NUTS using adapt_diag...
/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aeppl/joint_logprob.py:161: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: normal_rv{0, (0, 0), floatX, False}(RandomStateSharedVariable(<RandomState(MT19937) at 0x7FC4AB8F8E40>), TensorConstant{[]}, TensorConstant{11}, BroadcastTo.0, BroadcastTo.0)
warnings.warn(
/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aeppl/joint_logprob.py:161: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: normal_rv{0, (0, 0), floatX, False}(RandomStateSharedVariable(<RandomState(MT19937) at 0x7FC4AB8F8E40>), TensorConstant{[]}, TensorConstant{11}, BroadcastTo.0, BroadcastTo.0)
warnings.warn(
/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aeppl/joint_logprob.py:161: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: normal_rv{0, (0, 0), floatX, False}(RandomStateSharedVariable(<RandomState(MT19937) at 0x7FC4AA9BD340>), TensorConstant{[]}, TensorConstant{11}, BroadcastTo.0, BroadcastTo.0)
warnings.warn(
/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aeppl/joint_logprob.py:161: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: normal_rv{0, (0, 0), floatX, False}(RandomStateSharedVariable(<RandomState(MT19937) at 0x7FC4AA9BD540>), TensorConstant{[]}, TensorConstant{11}, BroadcastTo.0, BroadcastTo.0)
warnings.warn(
/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aeppl/joint_logprob.py:161: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: normal_rv{0, (0, 0), floatX, False}(RandomStateSharedVariable(<RandomState(MT19937) at 0x7FC4AA9BD640>), TensorConstant{[]}, TensorConstant{11}, BroadcastTo.0, BroadcastTo.0)
warnings.warn(
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [chol, μ]
4.92% [394/8000 00:00<00:17 Sampling 4 chains, 0 divergences]
/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/hmc/quadpotential.py:258: RuntimeWarning: divide by zero encountered in true_divide
np.divide(1, self._stds, out=self._inv_stds)
/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/hmc/quadpotential.py:237: RuntimeWarning: invalid value encountered in multiply
return np.multiply(self._var, x, out=out)
/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/hmc/quadpotential.py:258: RuntimeWarning: divide by zero encountered in true_divide
np.divide(1, self._stds, out=self._inv_stds)
/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/hmc/quadpotential.py:237: RuntimeWarning: invalid value encountered in multiply
return np.multiply(self._var, x, out=out)
/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/hmc/quadpotential.py:258: RuntimeWarning: divide by zero encountered in true_divide
np.divide(1, self._stds, out=self._inv_stds)
/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/hmc/quadpotential.py:237: RuntimeWarning: invalid value encountered in multiply
return np.multiply(self._var, x, out=out)
/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/hmc/quadpotential.py:258: RuntimeWarning: divide by zero encountered in true_divide
np.divide(1, self._stds, out=self._inv_stds)
/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/hmc/quadpotential.py:237: RuntimeWarning: invalid value encountered in multiply
return np.multiply(self._var, x, out=out)
/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aeppl/joint_logprob.py:161: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: normal_rv{0, (0, 0), floatX, False}(RandomStateSharedVariable(<RandomState(MT19937) at 0x7FC4AA9BD340>), TensorConstant{[]}, TensorConstant{11}, BroadcastTo.0, BroadcastTo.0)
warnings.warn(
/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aeppl/joint_logprob.py:161: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: normal_rv{0, (0, 0), floatX, False}(RandomStateSharedVariable(<RandomState(MT19937) at 0x7FC4AA9BD340>), TensorConstant{[]}, TensorConstant{11}, BroadcastTo.0, BroadcastTo.0)
warnings.warn(
/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aeppl/joint_logprob.py:161: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: normal_rv{0, (0, 0), floatX, False}(RandomStateSharedVariable(<RandomState(MT19937) at 0x7FC4AA9BD340>), TensorConstant{[]}, TensorConstant{11}, BroadcastTo.0, BroadcastTo.0)
warnings.warn(
/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aeppl/joint_logprob.py:161: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: normal_rv{0, (0, 0), floatX, False}(RandomStateSharedVariable(<RandomState(MT19937) at 0x7FC4AA9BD340>), TensorConstant{[]}, TensorConstant{11}, BroadcastTo.0, BroadcastTo.0)
warnings.warn(
---------------------------------------------------------------------------
RemoteTraceback Traceback (most recent call last)
RemoteTraceback:
"""
Traceback (most recent call last):
File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/parallel_sampling.py", line 125, in run
self._start_loop()
File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/parallel_sampling.py", line 178, in _start_loop
point, stats = self._compute_point()
File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/parallel_sampling.py", line 203, in _compute_point
point, stats = self._step_method.step(self._point)
File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/arraystep.py", line 286, in step
return super().step(point)
File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/arraystep.py", line 208, in step
step_res = self.astep(q)
File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/hmc/base_hmc.py", line 164, in astep
self.potential.raise_ok(q0.point_map_info)
File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/hmc/quadpotential.py", line 308, in raise_ok
raise ValueError("\n".join(errmsg))
ValueError: Mass matrix contains zeros on the diagonal.
The derivative of RV `chol_cholesky-cov-packed__`.ravel()[[0 1 2]] is zero.
The derivative of RV `μ`.ravel()[[0 1]] is zero.
"""
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
ValueError: Mass matrix contains zeros on the diagonal.
The derivative of RV `chol_cholesky-cov-packed__`.ravel()[[0 1 2]] is zero.
The derivative of RV `μ`.ravel()[[0 1]] is zero.
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
/tmp/ipykernel_30650/2034248556.py in <module>
1 with model:
----> 2 trace = pm.sample(
3 random_seed=RANDOM_SEED,
4 init="adapt_diag",
5 )
~/projects/pymc3_vs_stan/pymc/pymc/sampling.py in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
566 _print_step_hierarchy(step)
567 try:
--> 568 trace = _mp_sample(**sample_args, **parallel_args)
569 except pickle.PickleError:
570 _log.warning("Could not pickle model, sampling singlethreaded.")
~/projects/pymc3_vs_stan/pymc/pymc/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, **kwargs)
1483 try:
1484 with sampler:
-> 1485 for draw in sampler:
1486 trace = traces[draw.chain - chain]
1487 if trace.supports_sampler_stats and draw.stats is not None:
~/projects/pymc3_vs_stan/pymc/pymc/parallel_sampling.py in __iter__(self)
458
459 while self._active:
--> 460 draw = ProcessAdapter.recv_draw(self._active)
461 proc, is_last, draw, tuning, stats, warns = draw
462 self._total_draws += 1
~/projects/pymc3_vs_stan/pymc/pymc/parallel_sampling.py in recv_draw(processes, timeout)
347 else:
348 error = RuntimeError("Chain %s failed." % proc.chain)
--> 349 raise error from old_error
350 elif msg[0] == "writing_done":
351 proc._readable = True
RuntimeError: Chain 2 failed.

Hope some of this is helpful!
@martiningram, those … Surprisingly, this is not even related to the … Anyway, thanks for bringing it up. There is definitely a bug lurking around.
Oh weird, thanks for taking such a close look at this! Interesting that it's probably just due to …
This supersedes #4784.

Tests are currently failing due to aesara-devs/aesara#786; with a local patch, they pass.

The random method of LKJCorr now works properly for arbitrary sizes, even though the logp method is restricted to 2D values due to its reliance on matrix_pos_def.

TODO:
* tensor.take bug aesara-devs/aesara#786

Closes #4686
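As a rough illustration of what "works properly for arbitrary sizes" means for the LKJCorr random method, a hedged sketch (drawing via .eval() and the exact batch-shape semantics are assumptions, not taken from this PR):

import pymc as pm

# LKJCorr returns the packed strict upper triangle of the correlation matrix,
# so for n=3 each draw has n * (n - 1) // 2 = 3 entries.
corr_packed = pm.LKJCorr.dist(n=3, eta=2.0, size=(4, 2))
print(corr_packed.eval().shape)  # expected (4, 2, 3): batch dims plus packed entries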