Derive logprob for hyperbolic and error transformations #6664
Conversation
Codecov Report
Additional details and impacted files

```
@@            Coverage Diff             @@
##             main    #6664      +/-   ##
==========================================
+ Coverage   91.96%   92.00%   +0.03%
==========================================
  Files          94       95       +1
  Lines       15927    16101     +174
==========================================
+ Hits        14647    14813     +166
- Misses       1280     1288       +8
```
Before, I have just tested that the logp matches against equivalent RV forms. So then it boils down to:

```python
def test_erf_logp():
    base_rv = pt.random.normal(0.5, 1, name="base_rv")  # Something not centered around 0 is usually better
    rv = pt.erf(base_rv)
    vv = rv.clone()
    rv_logp = logp(rv, vv)
    assert_no_rvs(rv_logp)

    transform = ErfTransform()
    expected_logp = logp(rv, transform.backward(vv)) + transform.log_jac_det(vv)

    vv_test = np.array(0.25)  # Arbitrary test value
    np.testing.assert_almost_equal(
        rv_logp.eval({vv: vv_test}),
        expected_logp.eval({vv: vv_test}),
    )
```

You can probably parametrize and test all new functions with the same test. Alternatively you can try to hijack the existing test in pymc/tests/logprob/test_transforms.py (line 212 in 5d68bf3). That test now assumes you are testing only …
@ricardoV94 I've been working on 2. and the test now runs, however it's throwing an assertion error. The test is:

```python
@pytest.mark.parametrize("transform", [ErfTransform])
def test_erf_logp(transform):
    base_rv = pt.random.normal(0.5, 1, name="base_rv")  # Something not centered around 0 is usually better
    rv = pt.erf(base_rv)
    vv = rv.clone()
    rv_logp = joint_logprob({rv: vv})

    transform = transform()
    expected_logp = joint_logprob({rv: transform.backward(vv)}) + transform.log_jac_det(vv)

    vv_test = np.array(0.25)  # Arbitrary test value
    np.testing.assert_almost_equal(
        rv_logp.eval({vv: vv_test}),
        expected_logp.eval({vv: vv_test}),
    )
```

This gives the assertion error: …

The values are close but there's still quite a difference. I'm not sure whether this is down to the way I've written the logp in the test or whether the internal transform functions are wrong. Any ideas?
Oh, my example was wrong. You want to compare with the logp of base_rv. This passes locally:

```python
import numpy as np
import pytensor.tensor as pt

from pymc.logprob.basic import logp
from pymc.logprob.transforms import ErfTransform

base_rv = pt.random.normal(0.5, 1, name="base_rv")  # Something not centered around 0 is usually better
rv = pt.erf(base_rv)
vv = rv.clone()
rv_logp = logp(rv, vv)

transform = ErfTransform()
expected_logp = logp(base_rv, transform.backward(vv)) + transform.log_jac_det(vv)

vv_test = np.array(0.25)  # Arbitrary test value
np.testing.assert_almost_equal(
    rv_logp.eval({vv: vv_test}),
    expected_logp.eval({vv: vv_test}),
)
```
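For reference, the corrected comparison follows the standard change-of-variables formula, with base_rv playing the role of $X$ and erf the forward map $g$:

$$
\log p_Y(y) = \log p_X\bigl(g^{-1}(y)\bigr) + \log\left|\frac{d}{dy}\,g^{-1}(y)\right|,
\qquad Y = g(X) = \operatorname{erf}(X),
$$

which is why expected_logp is built from logp(base_rv, transform.backward(vv)) plus the log Jacobian determinant of the backward map.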
Cheers, will make the change!
Whoops, clicked the wrong button, didn't mean to close! So test 2. now works for all Transforms, however I'm having an issue with the … Note that for test 2. I had to make changes to the test by adding a switch statement, and I edited the switch statement on line 416 in transforms.py to check input_logprob AND jacobian, because of the discrepancy of returning nans vs. -infs: if input_logprob is nan then the switch also returns nan and not -inf.
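A minimal sketch of the nan-propagation issue being described (illustrative code, not from the PR; the scalar names are just placeholders standing in for the quantities in measurable_transform_logprob): when only the jacobian is checked for nan, a nan coming from input_logprob passes straight through the switch.

```python
import numpy as np
import pytensor.tensor as pt

# Placeholder scalars standing in for the logprob terms
input_logprob = pt.scalar("input_logprob")
jacobian = pt.scalar("jacobian")

# Guard on the jacobian only: a nan in input_logprob is not caught
guard_jacobian_only = pt.switch(pt.isnan(jacobian), -np.inf, input_logprob + jacobian)
# Guard on the sum: a nan from either term is mapped to -inf
guard_sum = pt.switch(pt.isnan(input_logprob + jacobian), -np.inf, input_logprob + jacobian)

print(guard_jacobian_only.eval({input_logprob: np.nan, jacobian: 0.0}))  # nan
print(guard_sum.eval({input_logprob: np.nan, jacobian: 0.0}))            # -inf
```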
No problem :)
I think there must have been an error in your … I think it's fine to use the default implementation (for the cases where it works). I didn't try to find what the error was.
Great, I'll take a look at what you did and see if I can try to implement it with the other transforms.
OK, all done, all tests pass.
Looks good. I don't know why the coverage shows some of the new transforms not being covered, just a fluke?
I just have a question about a change below.
pymc/logprob/transforms.py (Outdated)

```diff
@@ -391,7 +419,7 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwargs):
     jacobian = jacobian.sum(axis=tuple(range(-ndim_supp, 0)))

     # The jacobian is used to ensure a value in the supported domain was provided
-    return pt.switch(pt.isnan(jacobian), -np.inf, input_logprob + jacobian)
+    return pt.switch(pt.isnan(input_logprob + jacobian), -np.inf, input_logprob + jacobian)
```
Can we revert this change?
Potentially, let me check. The reason I did that was that it meant we return -np.inf consistently when input_logprob is nan, which is the case for some of the transforms.
Okay, it seems that isn't the case anymore and tests pass with the reverted change.
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
Awesome work @LukeLB! Looking forward to your next PR :)
Thanks @ricardoV94 it's been a pleasure! Thanks for reviewing :)
What is this PR about?
I have implemented additional Elemwise transformations as suggested in issue #6631. Specifically, this pull request adds cosh, sinh, tanh, erf, erfc, and erfcx functions. I plan to address the other suggested transformations in a separate pull request, as they require a more significant rewrite of existing functions. However, if it is preferred to include them all in one pull request, I'm happy to do so.
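To make the scope concrete, here is a minimal sketch of what one of these elementwise transforms might look like, assuming the RVTransform base class from pymc.logprob.transforms; the class name and details are illustrative, not necessarily the PR's actual implementation.

```python
import pytensor.tensor as pt

from pymc.logprob.transforms import RVTransform


class TanhTransformSketch(RVTransform):
    """Illustrative tanh transform: forward map y = tanh(x)."""

    name = "tanh"
    ndim_supp = 0

    def forward(self, value, *inputs):
        return pt.tanh(value)

    def backward(self, value, *inputs):
        return pt.arctanh(value)

    def log_jac_det(self, value, *inputs):
        # log |d arctanh(y) / dy| = -log(1 - y**2)
        return -pt.log1p(-value**2)
```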
Please note that this is still a work in progress, and I have not yet written any tests for the new Transforms. I would appreciate some guidance on how to design these tests, as it's not clear to me what I should be testing them against.
Also, for the erfcx transform it would be great to double-check that my math is correct: for the backward I have rewritten a MATLAB function, and for the log Jacobian determinant I used Wolfram Alpha to get the derivative of erfcx.
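For reference when checking that math (these are general identities, not taken from the PR's code): the derivative of erfcx, and the log Jacobian of the inverse map that follows from it by the inverse function theorem, are

$$
\operatorname{erfcx}(x) = e^{x^{2}}\operatorname{erfc}(x),
\qquad
\frac{d}{dx}\operatorname{erfcx}(x) = 2x\operatorname{erfcx}(x) - \frac{2}{\sqrt{\pi}},
$$

$$
\log\left|\frac{d}{dy}\operatorname{erfcx}^{-1}(y)\right|
= -\log\left|2x\operatorname{erfcx}(x) - \frac{2}{\sqrt{\pi}}\right|
\quad \text{at } x = \operatorname{erfcx}^{-1}(y).
$$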
...
Checklist
Major / Breaking Changes
- Changes to find_measureable_transforms(), as it was getting quite large

New features
- Transforms for: cosh, sinh, tanh, erf, erfc, and erfcx
Bugfixes
Documentation
Maintenance
📚 Documentation preview 📚: https://pymc--6664.org.readthedocs.build/en/6664/