Implement unconstraining transform for LKJCorr #7380
Conversation
pymc/distributions/transforms.py
Outdated
```python
# Are the diagonals always guaranteed to be positive?
# I don't know, so we'll use abs
row_norms = 1/pt.abs(pt.diag(chol))
```
Yep, always positive. You don't need abs here.
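For instance (a sketch, assuming `chol` is the Cholesky factor from the surrounding code):

```python
# Cholesky factors have strictly positive diagonals, so abs is redundant
row_norms = 1 / pt.diag(chol)
```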
pymc/distributions/transforms.py
Outdated
```python
)

def _jacobian(self, value, *inputs):
    return pt.jacobian(
```
`pt.jacobian` can be quite expensive, because it requires us to loop over every input and compute the associated symbolic gradients. There's a closed-form solution for the log-det jacobian in the TFP code, so you can eliminate this method and implement the closed-form log-det jac:
```python
n = ps.shape(y)[-1]
return -tf.reduce_sum(
    tf.range(2, n + 2, dtype=y.dtype) * tf.math.log(tf.linalg.diag_part(y)),
    axis=-1)
```
`diag_part` would just be `pt.diagonal(y, axis1=-2, axis2=-1)`. That will account for potential batching on `y`. So something like:
```python
n = y.shape[-1]
return -(pt.arange(2, n+2, dtype=y.dtype) * pt.log(pt.diagonal(y, axis1=-2, axis2=-1))).sum(axis=-1)
```
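In math form, the inverse log-det Jacobian this snippet computes is

$$\log\left|\det J^{-1}\right| = -\sum_{i=1}^{n} (i+1)\,\log y_{ii},$$

i.e. the $i$-th diagonal entry of $y$ is weighted by $i+1$, which is the `range(2, n+2)` term.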
Can you point me to some info on how that's derived? Going to need to modify it to work with the correlation matrix.
Never mind, found it in the comments in TFP's `_inverse_log_det_jacobian`.
pymc/distributions/transforms.py
Outdated
```python
row_indices, col_indices = np.tril_indices(self.n, -1)
return (
    pytensor.shared(row_indices),
    pytensor.shared(col_indices)
```
There's no need to save these as shared variables, you can use the numpy indices directly. Making the numpy indices is pretty cheap; I'm not sure it's worth it to cache them.
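For instance (a sketch; `chol` and `self.n` are assumed from the surrounding class):

```python
# plain numpy arrays work directly as pytensor indices; no shared variables needed
row_idx, col_idx = np.tril_indices(self.n, k=-1)
values = chol[row_idx, col_idx]
```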
pymc/distributions/transforms.py
Outdated
```python
    return unconstrained[self.tril_r_idxs, self.tril_c_idxs]

def backward(self, value, *inputs, foo=False):
```
You need to check that these functions match the expected outputs from TFP. I used the test case from the tfp docs and got the wrong values:

```python
array([0.89442719, 0.81649658, 0.91287093])
```

vs the reference solution:

```python
array([[1.        , 0.        , 0.        ],
       [0.70710678, 0.70710678, 0.        ],
       [0.66666667, 0.66666667, 0.33333333]])
```

You did some extra research so I might be missing something?
Something like this matches tfp:

```python
from functools import partial

import numpy as np
import pytensor.tensor as pt


def backward(self, value, *inputs):
    """
    Convert unconstrained real numbers to the off-diagonal elements of the
    cholesky decomposition of a correlation matrix.
    """

    def unpack_upper_tril_with_eye_diag(x, core_shape):
        """1D allocation case; fills the strict lower triangle (despite the name)."""
        return pt.set_subtensor(pt.eye(core_shape)[np.tril_indices(core_shape, k=-1)], x[::-1])

    value = pt.as_tensor_variable(value)
    # note: this is the length m of the unconstrained vector, which equals
    # the matrix size n only when n = 3 (m = n(n-1)/2 = 3)
    core_shape = value.type.shape[-1]
    # Vectorize the 1D case to handle potential batch dimensions
    out = pt.vectorize(partial(unpack_upper_tril_with_eye_diag, core_shape=core_shape), '(n)->(n,n)')(value)
    # Vector L2 norm without .real call to speed things up a bit
    norm = pt.sqrt(pt.sum((out ** 2), axis=-1, keepdims=True))
    return out / norm
```
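As a quick sanity check, this reproduces the 3x3 reference solution quoted earlier in the thread (the input `[2., 2., 1.]` is assumed from the TFP docs example):

```python
import numpy as np

x = np.array([2.0, 2.0, 1.0])
print(backward(None, x).eval())  # self is unused in the sketch above, so None suffices
# [[1.         0.         0.        ]
#  [0.70710678 0.70710678 0.        ]
#  [0.66666667 0.66666667 0.33333333]]
```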
Thanks for that code above, there's some great stuff I can reuse.
Need to address this comment first, since the transform actually working in the first place is fundamental. Here's a notebook that demonstrates that this implementation does replicate the original reference implementation from TFP: https://colab.research.google.com/drive/1BBNNfBUNJPGT_7MxVboTqvRegJ-TUamc?usp=sharing
Here's why it didn't work for you:
- The PyMC implementation needs to output the upper triangular elements of the correlation matrix, whereas the TFP implementation outputs a Cholesky factor.
- Differences in indexing off-diagonal elements: TFP actually fills in off-diagonal elements in a clockwise spiral, whereas `np.triu_indices` is row-major (see the order sketch below). I notice you reverse the elements in the code above, which is correct for 3x3 but not for larger matrices.
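For reference, `np.tril_indices` (used in the sketch above) emits indices in row-major order as well, which is why a plain reversal only lines up with TFP's spiral for 3x3:

```python
import numpy as np

rows, cols = np.tril_indices(4, k=-1)
print(list(zip(rows.tolist(), cols.tolist())))
# [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)] -- row by row, left to right
```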
I got access denied to the colab :)
I'm not surprised my code doesn't work, but I'm glad you know why. In the general case, tensorflow concatenates a reflected copy of the array to itself, then reshapes and masks out the lower/upper triangle -- see here if you haven't already. There's no reason why we couldn't just do that.
I'm not sure that we need to copy their output 1:1 -- after all, the important thing is that we can go from unconstrained samples to a valid Cholesky-decomposed correlation matrix. Is the order we put the numbers into the matrix relevant? I'm not sure, but my instinct is no. On the other hand, if we copy 1:1 we can be sure it's right.
> PyMC implementation needs to output the upper triangular elements of the correlation matrix, where the TFP implementation outputs a Cholesky factor.

Are you sure? I thought the upper triangular elements *are* the Cholesky-factorized correlation matrix. If you're right, though, we just need to add a matmul to the end, right?
Apologies - the link is now open with Viewer permission and I've made you an Editor.
I don't think the order we insert the off-diagonal elements into an array is very important, but it is needed in order to compare results between this implementation and the one in TFP. I would suggest sticking with `np.triu_indices` here.
> Are you sure? I thought the upper triangular elements are the cholesky factorized correlation matrix. If you're right though we just need to add a matmul to the end right?

Yes, you can see this by looking at the implementation of LKJCorr. I originally thought the same thing, implemented the transform accordingly, then was surprised that non-posdef matrices were generated. 🤦
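For illustration, the extra matmul step could look like this (a sketch; `chol` and the cached index arrays are assumptions based on this PR):

```python
# C = L L^T turns the Cholesky factor into the correlation matrix (batch-aware)
corr = pt.matmul(chol, pt.swapaxes(chol, -1, -2))
# then take the off-diagonal entries the transform is expected to output
upper = corr[..., self.triu_r_idxs, self.triu_c_idxs]
```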
Great! Sounds like you have a good handle on things. I think what would be a really important next step would be to add a test that your implementation correctly makes a round trip from the unconstrained space to the constrained space and back.
Awesome - I can do that
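A minimal sketch of such a round-trip test (the constructor name and shapes here are assumptions, not the PR's final API):

```python
import numpy as np

def test_lkjcorr_transform_round_trip():
    transform = CholeskyCorrTransform(n=4)  # hypothetical constructor
    x = np.random.normal(size=4 * 3 // 2)   # m = 6 unconstrained values
    x_roundtrip = transform.forward(transform.backward(x)).eval()
    np.testing.assert_allclose(x_roundtrip, x, rtol=1e-6)
```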
```python
self.n = n
self.m = int(n * (n - 1) / 2)  # number of off-diagonal elements
self.tril_r_idxs, self.tril_c_idxs = self._generate_tril_indices()
self.triu_r_idxs, self.triu_c_idxs = self._generate_triu_indices()
```
See below, not sure we need to cache these. `__init__` is probably unnecessary.
pymc/distributions/transforms.py
Outdated
```python
    jac = self._jacobian(value)
    return pt.log(pt.linalg.det(jac))

def forward(self, value, *inputs):
```
See below. I'm pretty sure this needs to go from matrix to vector (to match the tfp case). @junpenglao might know for sure.
+1, it is better for the unbounded value to be a vector.
Sorry, I am a bit confused and don't understand what you mean.
Specifically, do you mean this function needs to work along the last axis for arrays of arbitrary number of dimensions, and that the current iteration assumes that `value` will only have dimension 1?
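For illustration only, a batch-aware `forward` along the last two axes might look like this (a sketch using the cached indices from this PR, not a confirmed answer):

```python
def forward(self, value, *inputs):
    # map (..., n, n) matrices to (..., m) vectors of off-diagonal entries
    return value[..., self.tril_r_idxs, self.tril_c_idxs]
```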
```
@@ -1579,7 +1579,9 @@ def logp(value, n, eta):

@_default_transform.register(_LKJCorr)
def lkjcorr_default_transform(op, rv):
    return MultivariateIntervalTransform(-1.0, 1.0)
```
Can you delete this transform class as well? It was a (wrong) patch to the problem you're solving
Can do. Just to confirm, you don't consider `MultivariateIntervalTransform` to be part of pymc's public API?
Nope, can be removed without worries
Ok - great
```python
    self.triu_r_idxs, self.triu_c_idxs = self._generate_triu_indices()

def _generate_tril_indices(self):
    row_indices, col_indices = np.tril_indices(self.n, -1)
```
Not sure if it matters, but there is a `pt.tril_indices` and `pt.triu_indices`, so no need to eval `n`. If it's already restricted to be constant elsewhere (like the logp), then it's fine either way.
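A sketch of the symbolic version this refers to (assuming `pt.tril_indices` accepts a symbolic `n`, per the comment above):

```python
import pytensor.tensor as pt

n = pt.iscalar("n")                    # n can stay symbolic; no eval needed
rows, cols = pt.tril_indices(n, k=-1)  # strict lower-triangle indices
```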
I think it's good practice to use the `pt` version, even if `n` is fixed.
I originally tried to use the `pt` version, but one of the function calls required constant values. However, I've made so many changes that this might no longer be the case. I'll try the `pt` version again and see if I can get it to work.
Hi, it's unlikely I'm going to have any time to work on this for the next 6 months. The hardest part is coming up with a closed-form solution for log_det_jac, which I don't think I'm very close to doing.
Thanks for the update @johncant and for pushing this as far as you did.
```python
computed_log_jac_det = transform.log_jac_det(y).eval()

# Expected log determinant: 0 (since row norms are 1)
expected_log_jac_det = 0.0
```
Weak test. Tell it to compare against pytensor's jacobian machinery with a non-trivial input, and to reuse test code that already exists to do that.
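A sketch of that comparison (the transform API and a vector-to-vector `backward` are assumed from this thread):

```python
import numpy as np
import pytensor.tensor as pt

x = pt.dvector("x")
y = transform.backward(x)                 # assumed vector -> vector here
jac = pt.jacobian(y, x)                   # brute-force reference jacobian
ref = pt.log(pt.abs(pt.linalg.det(jac)))

x_val = np.random.normal(size=6)          # non-trivial input, n = 4
np.testing.assert_allclose(
    transform.log_jac_det(x).eval({x: x_val}),
    ref.eval({x: x_val}),
    rtol=1e-6,
)
```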
I've ported this bijector from `tensorflow` and added it to `LKJCorr`. This ensures that initial samples drawn from `LKJCorr` are positive definite, which fixes #7101. Sampling now completes successfully with no divergences.

There are several parts I'm not comfortable with:
- How do I get the `n` parameter from `op` or `rv` without `eval`ing any pytensors?

@fonnesbeck @twiecki @jessegrabowski @velochy - please could you take a look? I would like to make sure that this fix makes sense before adding tests and making the linters pass.
Notes:
- `forward` in `tensorflow_probability` is `backward` in `pymc`
Description
Backward method
Forward method
`log_jac_det`
This was quite complicated to implement, so I used the symbolic jacobian.
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7380.org.readthedocs.build/en/7380/