Skip to content

Commit

Permalink
Update dependency to aesara 2.0.8, and necessary fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 authored and twiecki committed May 12, 2021
1 parent 2c372ef commit de74ff6
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pymc3/distributions/dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def rho2sigma(rho):
"""
`rho -> sigma` Aesara converter
:math:`mu + sigma*e = mu + log(1+exp(rho))*e`"""
return at.nnet.softplus(rho)
return at.softplus(rho)


rho2sd = rho2sigma
Expand Down
8 changes: 4 additions & 4 deletions pymc3/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class LogExpM1(ElemwiseTransform):
name = "log_exp_m1"

def backward(self, rv_var, rv_value):
return at.nnet.softplus(rv_value)
return at.softplus(rv_value)

def forward(self, rv_var, rv_value):
"""Inverse operation of softplus.
Expand All @@ -160,7 +160,7 @@ def forward(self, rv_var, rv_value):
return at.log(1.0 - at.exp(-rv_value)) + rv_value

def jacobian_det(self, rv_var, rv_value):
return -at.nnet.softplus(-rv_value)
return -at.softplus(-rv_value)


log_exp_m1 = LogExpM1()
Expand Down Expand Up @@ -191,7 +191,7 @@ def backward(self, rv_var, rv_value):
a, b = self.param_extract_fn(rv_var)

if a is not None and b is not None:
sigmoid_x = at.nnet.sigmoid(rv_value)
sigmoid_x = at.sigmoid(rv_value)
return sigmoid_x * b + (1 - sigmoid_x) * a
elif a is not None:
return at.exp(rv_value) + a
Expand All @@ -215,7 +215,7 @@ def jacobian_det(self, rv_var, rv_value):
a, b = self.param_extract_fn(rv_var)

if a is not None and b is not None:
s = at.nnet.softplus(-rv_value)
s = at.softplus(-rv_value)
return at.log(b - a) - 2 * s - rv_value
else:
return rv_value
Expand Down
4 changes: 2 additions & 2 deletions pymc3/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
or_,
prod,
sgn,
sigmoid,
sin,
sinh,
sqr,
Expand All @@ -78,7 +79,6 @@


from aesara.tensor.nlinalg import det, matrix_dot, matrix_inverse, trace
from aesara.tensor.nnet import sigmoid
from scipy.linalg import block_diag as scipy_block_diag

from pymc3.aesaraf import floatX, ix_, largest_common_dtype
Expand Down Expand Up @@ -229,7 +229,7 @@ def log1pexp(x):
This function is numerically more stable than the naive approach.
"""
return at.nnet.softplus(x)
return at.softplus(x)


def log1mexp(x):
Expand Down
6 changes: 3 additions & 3 deletions pymc3/variational/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,15 +390,15 @@ def make_uw(self, u, w):
# u_: d
# w_: d
wu = u.dot(w) # .
mwu = -1.0 + at.nnet.softplus(wu) # .
mwu = -1.0 + at.softplus(wu) # .
# d + (. - .) * d / .
u_h = u + (mwu - wu) * w / ((w ** 2).sum() + 1e-10)
return u_h, w
else:
# u_: bxd
# w_: bxd
wu = (u * w).sum(-1, keepdims=True) # bx-
mwu = -1.0 + at.nnet.softplus(wu) # bx-
mwu = -1.0 + at.softplus(wu) # bx-
# bxd + (bx- - bx-) * bxd / bx- = bxd
u_h = u + (mwu - wu) * w / ((w ** 2).sum(-1, keepdims=True) + 1e-10)
return u_h, w
Expand Down Expand Up @@ -507,7 +507,7 @@ def __init__(self, **kwargs):

def make_ab(self, a, b):
a = at.exp(a)
b = -a + at.nnet.softplus(b)
b = -a + at.softplus(b)
return a, b


Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
aesara>=2.0.5
aesara>=2.0.8
arviz>=0.11.2
cachetools>=4.2.1
dill
Expand Down

0 comments on commit de74ff6

Please sign in to comment.