Add Gaussian Hypergeometric Function Hyp2F1
#1288
Looks good. Seems like we might need to go over the test values being used and/or potentially some subtleties of those gradient implementations.
Thanks for the quick response. The docs for scipy.special.factorial suggest it only accepts integer arguments, but it works just fine with floats. Changing the testing input types for …
What could be causing this? As for …
Turns out …
If I could get some more clarification on how to resolve these …
Yeah, some version of the gamma function should be/is likely being used in that case.
The magnitudes of those printed values might be indicating under/overflow of some sort. A numerical approximation is used for testing, and, if both values are so large, that could mean the test points are in an unstable/unsupported range of values for the given gradient representations (i.e. the exact functional forms of the gradients being used) and/or approximation routine.
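For a quick offline sanity check of a candidate test point, one can compare a central finite difference of scipy.special.hyp2f1 against the closed-form derivative with respect to z. A minimal sketch, with purely illustrative parameter values (not the ones used in this PR's tests):

```python
# Sanity-check a candidate (a, b, c, z) point for the z-gradient of 2F1.
# Closed form: d/dz 2F1(a, b; c; z) = (a * b / c) * 2F1(a + 1, b + 1; c + 1; z)
from scipy.special import hyp2f1


def fd_grad_z(a, b, c, z, eps=1e-6):
    """Central finite-difference approximation of the derivative w.r.t. z."""
    return (hyp2f1(a, b, c, z + eps) - hyp2f1(a, b, c, z - eps)) / (2 * eps)


a, b, c, z = 0.5, 1.5, 2.5, 0.25  # illustrative values only
analytic = (a * b / c) * hyp2f1(a + 1, b + 1, c + 1, z)
print(analytic, fd_grad_z(a, b, c, z))
# If both numbers are huge (or wildly different), the point is probably in an
# unstable/unsupported range for the gradient form being tested.
```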
aesara/scalar/math.py (Outdated)

```python
class Factorial(UnaryScalarOp):
    """
    Factorial function of a scalar or array of numbers.

    """

    nfunc_spec = ("scipy.special.factorial", 1, 1)

    @staticmethod
    def st_impl(n):
        return scipy.special.factorial(n)

    def impl(self, n):
        return Factorial.st_impl(n)

    def grad(self, inputs, grads):
        (n,) = inputs
        (gz,) = grads
        # d/dn Gamma(n + 1) = Gamma(n + 1) * psi(n + 1), where psi is the digamma function
        return [gz * gamma(n + 1) * psi(n + 1)]

    def c_code(self, *args, **kwargs):
        raise NotImplementedError()
```
We should see whether or not the existing Gamma works for this instead. For instance, we could change factorial below to a helper function that constructs Gamma graphs.
I don't know if this is exactly kosher, but with …
Like this?
Because if the same is done for …
Yes, let's try that first. We may need to follow these changes with some rewrites that produce better gradients in certain cases, but that's good, because it better fits the design and intention of the library. Also, no worries if you're not familiar with the rewriting aspects of Aesara; we'll help with or follow up on that.
Great question; the answer would be "yes" (to both questions, perhaps) if those functions are part of the SciPy interface or have equivalents of some sort. If they aren't, then …
Sorry, I meant that we should replace …
So replace the current …
Doing this for …
I want to add that …
Well, more like the following:

```python
def factorial(x: TensorLike) -> TensorVariable:
    return gamma(x + 1)
```

A user could then construct a graph with …
We can continue pursuing a custom …
The typing variables seem to be throwing circular import errors: …
Dispensing with type hinting is raising a new error in …
If this is the case, …
I was thinking of helper/graph-constructor functions like the following:

```python
from typing import TYPE_CHECKING

import aesara
import aesara.tensor as at

if TYPE_CHECKING:
    from aesara.tensor import TensorLike, TensorVariable


def factorial(n: "TensorLike") -> "TensorVariable":
    return at.gamma(n + 1)


def poch(z: "TensorLike", m: "TensorLike") -> "TensorVariable":
    return at.gamma(z + m) / at.gamma(z)


z = at.scalar("z")
n = at.scalar("n")

res = (factorial(n), poch(z, n))
res_fn = aesara.function([z, n], res)
res_fn(2.0, 7)
# [array(5040.), array(40320.)]
```
Importing …
Ah, yes, …
Defining …
All tests will pass if …
We can't use those test constructors in this case, since they're probably only intended to be used with direct …
Ok, that sounds straightforward. The current tests are also passing if the helper functions call the scalar … What about the inplace definition in …?
Helper functions and their respective tests have been written, but the …
These are such simple helper implementations for …
aesara/scalar/math.py (Outdated)

```python
def poch(z: ScalarType, m: ScalarType) -> ScalarVariable:
    """
    Pochhammer symbol (rising factorial) function.

    """
    return gamma(z + m) / gamma(z)


def factorial(n: ScalarType) -> ScalarVariable:
    """
    Factorial function of a scalar or array of numbers.

    """
    return gamma(n + 1)
```
These need to be in aesara.tensor.[special|math], and aesara.tensor.math.gamma needs to be used, because aesara.tensor.math.gamma is an Elemwise Op. That Elemwise version of gamma is created from the scalar aesara.scalar.math.Gamma Op via the scalar_elemwise decorator in aesara.tensor.math, which creates the Elemwise like gamma = Elemwise(Gamma, ...).

In general, Elemwise Ops are the most common way that ScalarOps are used in Aesara, since they add broadcasting to scalar operations.
aesara/tensor/math.py (Outdated)

```python
@scalar_elemwise
def poch(z, m):
    """pochhammer symbol (rising factorial) function"""


@scalar_elemwise
def factorial(n):
    """factorial function"""
```
The scalar_elemwise decorator will look for a ScalarOp instance with the name poch/factorial, and this is where the errors likely start. Instead, if we remove these and define poch and factorial using the Elemwise version of gamma (i.e. the one created by scalar_elemwise in aesara.tensor.math), then everything should work.
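In other words, something along the lines of the following sketch, with poch and factorial as plain graph constructors built on the Elemwise gamma (the exact target module, aesara.tensor.special vs. aesara.tensor.math, is whichever this PR settles on):

```python
# Sketch of the suggested approach: no scalar_elemwise decorator and no new
# ScalarOp; just helpers that build graphs from the existing Elemwise gamma.
from aesara.tensor.math import gamma


def poch(z, m):
    """Pochhammer symbol (rising factorial): gamma(z + m) / gamma(z)."""
    return gamma(z + m) / gamma(z)


def factorial(n):
    """Factorial expressed via the gamma function: gamma(n + 1)."""
    return gamma(n + 1)
```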
Tests are passing! Still need to double-check everything, but this should be about ready for review.
If there are no other comments, I'm going to merge the base branch changes and submit this for review.
Sounds good! N.B. You'll have to rebase onto …
I just rebased, squashed, and removed the type hints. Looks like we need to clean up some other type-related things before we can add type hints to those functions.
Codecov Report

```
@@            Coverage Diff             @@
##             main    #1288      +/-   ##
==========================================
+ Coverage   74.69%   74.71%   +0.01%
==========================================
  Files         194      194
  Lines       49730    49826      +96
  Branches    10527    10539      +12
==========================================
+ Hits        37145    37226      +81
- Misses      10262    10272      +10
- Partials     2323     2328       +5
==========================================
```
Thanks! (Sorry for my delay in getting around to this.) Looks like the testing coverage report is failing because the derivatives aren't covered for … Note that when/if …
Yeah, that's probably it.
Definitely. We have quite a few big improvements to …
Do you think these enhancements will be merged in the next 1-2 months? If so, I'd prefer to wait until that happens. Otherwise, I'll be submitting an additional follow-up PR with the aforementioned revisions.
The question is really whether or not we'll have the necessary improvements available for all backends (i.e. C, JAX, and Numba) in that time, and I can say that we're primarily focusing on Numba and JAX at the moment, so 1-2 months might not hold for the C backend. Regardless, multiple/staged PRs are perfectly fine (and often preferable) on our end, so no worries there.
Following up on this: there are some Codecov annotations saying that we're missing lines in Hyp2F1 and Hyp2F1Der, so we need to either confirm that those annotations are erroneous and/or add tests that reach those lines.

After that, this is good to merge.
Tests need to be added for … The missing lines in …
Yeah, I think that's how we can get the missing coverage for both classes.
I've started writing these additional tests, and could use some help with the gradients and their respective tests. Here's how the tests are set up for …
They are returning this jumbled error for …
Here's the testing setup for the gradients Op: …
The last parameter for this Op is a flag indicating which variable the derivative is taken with respect to, but the …
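For what it's worth, one way to exercise the gradient with respect to a single argument is to hold the other arguments fixed in a closure and let aesara.gradient.verify_grad compare the symbolic gradient against a finite-difference estimate. This is only a sketch: the test point is made up, and the hyp2f1 import path is assumed from this PR.

```python
# Sketch: check only the z-gradient of the new hyp2f1 Op, with a, b, c fixed.
import numpy as np

from aesara.gradient import verify_grad
from aesara.tensor.math import hyp2f1  # import path assumed from this PR

rng = np.random.RandomState(23)  # verify_grad expects an explicit rng

verify_grad(
    lambda z: hyp2f1(0.5, 1.5, 2.5, z),  # differentiate w.r.t. z only
    [np.asarray(0.25)],
    rng=rng,
)
```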
@ColtAllen, right after your last update on this, it looks like you were asked by the PyMC [Labs] group to move this work to their fork of Aesara, where it was later completed. I don't know if you're aware, but they are downstreaming Aesara, so working on it here means that it will eventually show up there. Also, it should be possible to create a PR for both repos based on the same underlying branch, in which case there would be no delay or disparity.

I apologize if our review efforts weren't to your liking, especially since you seem to have felt the need to take the results elsewhere and leave things in this state. If you had external pressures and, for example, needed things to go faster, we would've preferred that you inform us so that we could have accommodated.

Regardless, we know that the situation created by PyMC [Labs] and their fork can be confusing and divisive, and, because of this, we're more than willing to help however we can.
Hey @brandonwillard, I apologize for my lack of communication in this matter. I spent most of December traveling and have been playing catch-up since I got back last night.

I created this PR because it's a key backend requirement for the btyd library I've been working on, for which I've also been the sole developer. The PyMC [Labs] team are working on a very similar library and reached out to me about merging efforts. I was thrilled to join them because communities are paramount to the success and survival of open-source projects (plus I prefer working with others over going solo). It was determined in the downstream PR that the gradients for the …

I attempted a PR beyond my abilities and ultimately abandoned one project in favor of another, so if anyone is at fault here, it's me. Your comments were very helpful while I was working on this; in fact, what I enjoyed most about this PR was checking my email and getting a notification that you had responded.
Thought I'd wrap up this PR as a courtesy. Gradients and respective tests have been added.
We can merge this, but we need to create a follow-up issue for a C/Cython/Numba implementation of Hyp2F1Der.impl.
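For reference, such a follow-up would presumably evaluate the hypergeometric series directly in compiled code. Below is a minimal Numba sketch of the plain 2F1 power series (valid only for |z| < 1), shown just to indicate the flavor; it is not the implementation referenced above and omits the derivative series and the usual analytic-continuation transformations.

```python
# Minimal sketch of a Numba-compiled 2F1 power series; |z| < 1 only.
import numba


@numba.njit
def hyp2f1_series(a, b, c, z, max_terms=1000, rtol=1e-15):
    term = 1.0   # current series term: (a)_n (b)_n z^n / ((c)_n n!)
    total = 1.0  # running sum, starting from the n = 0 term
    for n in range(max_terms):
        # ratio of term n+1 to term n
        term *= (a + n) * (b + n) / ((c + n) * (n + 1)) * z
        total += term
        if abs(term) <= rtol * abs(total):
            break
    return total
```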
This PR is in reference to #1046 and adds Ops for scipy.special.hyp2f1, scipy.special.poch, and scipy.special.factorial. Currently all three are failing the test_grad tensor broadcasting tests in aesara/tests/tensor/test_math_scipy.py. Here are the results from pytest:

TestHyp2F1Broadcast & TestHyp2F1InplaceBroadcast
TestPochBroadcast & TestPochInplaceBroadcast
TestFactorialBroadcast & TestFactorialInplaceBroadcast