-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add softmax to math #5279
Add softmax to math #5279
Conversation
19f6066
to
97341bb
Compare
Codecov Report
@@ Coverage Diff @@
## main #5279 +/- ##
=======================================
Coverage 80.40% 80.41%
=======================================
Files 82 82
Lines 14126 14132 +6
=======================================
+ Hits 11358 11364 +6
Misses 2768 2768
|
Some scary looking errors on windows, most likely triggered by the new Aesara version: https://github.com/pymc-devs/pymc/runs/4605536352?check_suite_focus=true#step:7:21 aesara-devs/aesara#701 would be my first guess |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PR is blocked by aesara-devs/aesara#707 and will be delayed until we can upgrade to an Aesara release that includes a fix.
97341bb
to
8af7d80
Compare
def softmax(x, axis=None): | ||
# Ignore vector case UserWarning issued by Aesara. This can be removed once Aesara | ||
# drops that warning | ||
with warnings.catch_warnings(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I checked locally that this does not screw up UserWarnings elsewhere
This PR adds softmax and log_softmax wrappers to the new Aesara Ops which accept the axis argument. The wrappers are there just to suppress the deprecation warnings issue on the Aesara side.
Closes #4226