Harmonize softplus implementations #884

Merged (1 commit) on Nov 24, 2022
aesara/link/jax/dispatch/scalar.py: 11 additions & 3 deletions

@@ -125,10 +125,18 @@ def psi(x):
 @jax_funcify.register(Softplus)
 def jax_funcify_Softplus(op, **kwargs):
     def softplus(x):
-        # This expression is numerically equivalent to the Aesara one
-        # It just contains one "speed" optimization less than the Aesara counterpart
         return jnp.where(
-            x < -37.0, jnp.exp(x), jnp.where(x > 33.3, x, jnp.log1p(jnp.exp(x)))
+            x < -37.0,
+            jnp.exp(x),
+            jnp.where(
+                x < 18.0,
+                jnp.log1p(jnp.exp(x)),
+                jnp.where(
+                    x < 33.3,
+                    x + jnp.exp(-x),
+                    x,
+                ),
+            ),
         )

     return softplus
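
The harmonized branches can be sanity-checked against the naive formula. The following is a minimal sketch of such a check, not part of the diff; it assumes jax is installed, and plain NumPy behaves the same way:

# Check that the branched softplus matches the naive formula where the
# latter is well behaved, and stays finite where it is not.
import jax.numpy as jnp

def softplus(x):
    # Same branches as the harmonized implementation above.
    return jnp.where(
        x < -37.0,
        jnp.exp(x),
        jnp.where(
            x < 18.0,
            jnp.log1p(jnp.exp(x)),
            jnp.where(x < 33.3, x + jnp.exp(-x), x),
        ),
    )

x = jnp.array([-100.0, 0.0, 20.0, 100.0])
print(softplus(x))            # [~3.8e-44  0.6931  20.  100.]
print(jnp.log1p(jnp.exp(x)))  # [~3.8e-44  0.6931  20.  inf]: exp(100.) overflows in float32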
aesara/scalar/math.py: 20 additions & 32 deletions

@@ -6,6 +6,7 @@

 import os
 import warnings
+from textwrap import dedent

 import numpy as np
 import scipy.special
@@ -1134,7 +1135,8 @@ class Softplus(UnaryScalarOp):
     r"""
     Compute log(1 + exp(x)), also known as softplus or log1pexp

-    This function is numerically more stable than the naive approach.
+    This function is faster than the naive approach, and does not overflow
+    for large values of x.

     For details, see
     https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
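
The overflow claim can be checked directly. A small sketch with NumPy, illustrative only and not part of the diff:

# exp(1000.) overflows to inf in float64, so the naive formula returns inf,
# while the piecewise form used by this Op returns x unchanged for x >= 33.3.
import numpy as np

x = np.float64(1000.0)
with np.errstate(over="ignore"):
    naive = np.log1p(np.exp(x))
stable = x + np.exp(-x) if x < 33.3 else x
print(naive, stable)  # inf 1000.0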
@@ -1172,52 +1174,38 @@ def grad(self, inp, grads):
     def c_code(self, node, name, inp, out, sub):
         (x,) = inp
         (z,) = out
-        # The boundary constants were obtained by looking at the output of
-        # python commands like:
-        #  import numpy, aesara
-        #  dt='float32'  # or float64
-        #  for i in range(750):
-        #      print i, repr(numpy.log1p(numpy.exp(_asarray([i,-i], dtype=dt))))
-        # the upper boundary check prevents us from generating inf, whereas the
-        # the lower boundary check prevents using exp when the result will be 0 anyway.
         # The intermediate constants are taken from Machler (2012).

-        # We use the float32 limits for float16 for now as the
-        # computation will happen in float32 anyway.
+        # We use the same limits for all precisions, which may be suboptimal. The
+        # reference paper only looked at double precision.
         if node.inputs[0].type in float_types:
             if node.inputs[0].type == float64:
-                return (
-                    """
-                    %(z)s = (
-                        %(x)s < -745.0 ? 0.0 :
-                        %(x)s < -37.0 ? exp(%(x)s) :
-                        %(x)s < 18.0 ? log1p(exp(%(x)s)) :
-                        %(x)s < 33.3 ? %(x)s + exp(-%(x)s) :
-                        %(x)s
+                return dedent(
+                    f"""
+                    {z} = (
+                        {x} < -37.0 ? exp({x}) :
+                        {x} < 18.0 ? log1p(exp({x})) :
+                        {x} < 33.3 ? {x} + exp(-{x}) :
+                        {x}
                     );
                     """
-                    % locals()
                 )
             else:
-                return (
-                    """
-                    %(z)s = (
-                        %(x)s < -103.0f ? 0.0 :
-                        %(x)s < -37.0f ? exp(%(x)s) :
-                        %(x)s < 18.0f ? log1p(exp(%(x)s)) :
-                        %(x)s < 33.3f ? %(x)s + exp(-%(x)s) :
-                        %(x)s
+                return dedent(
+                    f"""
+                    {z} = (
+                        {x} < -37.0f ? exp({x}) :
+                        {x} < 18.0f ? log1p(exp({x})) :
+                        {x} < 33.3f ? {x} + exp(-{x}) :
+                        {x}
                     );
                     """
-                    % locals()
                 )
         else:
             raise NotImplementedError("only floatingpoint is implemented")

     def c_code_cache_version(self):
         v = super().c_code_cache_version()
         if v:
-            return (2,) + v
+            return (3,) + v
         else:
             return v

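For reference, the branch constants can be verified numerically in double precision, following the Machler (2012) note cited in the docstring. A short sketch, illustrative only:

import numpy as np

# Above ~33.3, exp(-x) is well below the float64 spacing around x, so
# x + exp(-x) already rounds to x and the final branch can return x directly.
x = 34.0
print(x + np.exp(-x) == x)  # True: exp(-34) ~ 1.7e-15, ulp(34.) ~ 7.1e-15

# Above 18.0, x + exp(-x) agrees with log1p(exp(x)) to within rounding while
# avoiding the large intermediate exp(x).
x = 20.0
print(np.log1p(np.exp(x)), x + np.exp(-x))  # both ~20.000000002061

# The old x < -745.0 short-circuit to 0.0 was redundant: exp(x) underflows
# to exactly 0.0 there anyway, so dropping it does not change results.
print(np.exp(-750.0))  # 0.0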