Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add softmax to math #5279

Merged
merged 2 commits into from
Jan 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
- `pm.Data` now passes additional kwargs to `aesara.shared`/`at.as_tensor`. [#5098](https://github.com/pymc-devs/pymc/pull/5098).
- Univariate censored distributions are now available via `pm.Censored`. [#5169](https://github.com/pymc-devs/pymc/pull/5169)
- Nested models now inherit the parent model's coordinates. [#5344](https://github.com/pymc-devs/pymc/pull/5344)
- `softmax` and `log_softmax` functions added to `math` module (see [#5279](https://github.com/pymc-devs/pymc/pull/5279)).
- ...


Expand Down
16 changes: 16 additions & 0 deletions pymc/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,22 @@ def invlogit(x, eps=None):
return at.sigmoid(x)


def softmax(x, axis=None):
    """Return ``at.nnet.softmax(x, axis=axis)`` with ``UserWarning`` silenced.

    Parameters
    ----------
    x
        Input forwarded unchanged to ``at.nnet.softmax``.
    axis
        Forwarded as the ``axis`` keyword of ``at.nnet.softmax``.
        Defaults to ``None``.

    Returns
    -------
    Whatever ``at.nnet.softmax`` returns for the given arguments.
    """
    # Aesara issues a UserWarning for the vector case; it is ignored here.
    # This workaround can be removed once Aesara drops that warning.
    # PR review note (author): checked locally that this filter does not
    # interfere with UserWarnings raised elsewhere.
    with warnings.catch_warnings():
        warnings.simplefilter(action="ignore", category=UserWarning)
        result = at.nnet.softmax(x, axis=axis)
    return result


def log_softmax(x, axis=None):
    """Return ``at.nnet.logsoftmax(x, axis=axis)`` with ``UserWarning`` silenced.

    Parameters
    ----------
    x
        Input forwarded unchanged to ``at.nnet.logsoftmax``.
    axis
        Forwarded as the ``axis`` keyword of ``at.nnet.logsoftmax``.
        Defaults to ``None``.

    Returns
    -------
    Whatever ``at.nnet.logsoftmax`` returns for the given arguments.
    """
    # Aesara issues a UserWarning for the vector case; it is ignored here.
    # This workaround can be removed once Aesara drops that warning.
    with warnings.catch_warnings():
        warnings.simplefilter(action="ignore", category=UserWarning)
        result = at.nnet.logsoftmax(x, axis=axis)
    return result


def logbern(log_p):
if np.isnan(log_p):
raise FloatingPointError("log_p can't be nan.")
Expand Down
8 changes: 2 additions & 6 deletions pymc/step_methods/metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import numpy as np
import numpy.random as nr
import scipy.linalg
import scipy.special

from aesara.graph.fg import MissingInputError
from aesara.tensor.random.basic import BernoulliRV, CategoricalRV
Expand Down Expand Up @@ -608,7 +609,7 @@ def metropolis_proportional(self, q, logp, logp_curr, dim, k):
if candidate_cat != given_cat:
q.data[dim] = candidate_cat
log_probs[candidate_cat] = logp(q)
probs = softmax(log_probs)
probs = scipy.special.softmax(log_probs, axis=0)
prob_curr, probs[given_cat] = probs[given_cat], 0.0
probs /= 1.0 - prob_curr
proposed_cat = nr.choice(candidates, p=probs)
Expand Down Expand Up @@ -995,11 +996,6 @@ def sample_except(limit, excluded):
return candidate


def softmax(x):
    """Exponentiate ``x`` and normalise the result by its sum along axis 0.

    The global maximum of ``x`` is subtracted before exponentiating, so the
    largest exponent evaluated is ``exp(0)``.
    """
    shifted = x - np.max(x)
    weights = np.exp(shifted)
    normaliser = np.sum(weights, axis=0)
    return weights / normaliser


def delta_logp(point, logp, vars, shared):
[logp0], inarray0 = pm.join_nonshared_inputs(point, [logp], vars, shared)

Expand Down
23 changes: 23 additions & 0 deletions pymc/tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@
kronecker,
log1mexp,
log1mexp_numpy,
log_softmax,
logdet,
logdiffexp,
logdiffexp_numpy,
probit,
softmax,
)
from pymc.tests.helpers import SeededTest, verify_grad

Expand Down Expand Up @@ -265,3 +267,24 @@ def test_invlogit_deprecation_warning():
assert not record

assert np.isclose(res, res_zero_eps)


@pytest.mark.parametrize(
    "aesara_function, pymc_wrapper",
    [
        (at.nnet.softmax, softmax),
        (at.nnet.logsoftmax, log_softmax),
    ],
)
def test_softmax_logsoftmax_no_warnings(aesara_function, pymc_wrapper):
    """Test that wrappers for aesara functions do not issue Warnings"""
    # Imported locally so this test does not depend on the module's import
    # block. ``pytest.warns(None)`` is deprecated in pytest 7 and rejected in
    # pytest 8, so warnings are recorded with the standard library instead.
    import warnings

    vector = at.vector("vector")

    # Calling the Aesara function directly is expected to emit exactly these
    # two warning categories.
    with warnings.catch_warnings(record=True) as record:
        # "always" prevents the once-per-location registry from hiding
        # warnings that were already emitted earlier in the session.
        warnings.simplefilter("always")
        aesara_function(vector)
    categories = {warning.category for warning in record}
    assert categories == {UserWarning, FutureWarning}

    # The PyMC wrapper must emit no warnings at all.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        pymc_wrapper(vector)
    assert not record