Skip to content

Commit

Permalink
Implement betainc and derivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo committed Jun 7, 2021
1 parent be0ea5c commit 91d36ac
Show file tree
Hide file tree
Showing 6 changed files with 340 additions and 1 deletion.
4 changes: 4 additions & 0 deletions aesara/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,6 +1276,10 @@ class BinaryScalarOp(ScalarOp):
nin = 2


class TernaryScalarOp(ScalarOp):
nin = 3


class LogicalComparison(BinaryScalarOp):
def __init__(self, *args, **kwargs):
BinaryScalarOp.__init__(self, *args, **kwargs)
Expand Down
224 changes: 224 additions & 0 deletions aesara/scalar/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@
from aesara.gradient import grad_not_implemented
from aesara.scalar.basic import (
BinaryScalarOp,
TernaryScalarOp,
UnaryScalarOp,
complex_types,
discrete_types,
exp,
float64,
float_types,
log,
log1p,
upcast,
upgrade_to_float,
upgrade_to_float64,
Expand Down Expand Up @@ -1070,3 +1073,224 @@ def c_code_cache_version(self):


softplus = Softplus(upgrade_to_float, name="scalar_softplus")


class BetaInc(TernaryScalarOp):
"""
Regularized incomplete beta function
"""

nfunc_spec = ("scipy.special.betainc", 3, 1)

def impl(self, a, b, x):
return scipy.special.betainc(a, b, x)

def grad(self, inp, grads):
a, b, x = inp
(gz,) = grads

return [
gz * betainc_dda_scalar(a, b, x),
gz * betainc_ddb_scalar(a, b, x),
gz
* exp(
log1p(-x) * (b - 1)
+ log(x) * (a - 1)
- (gammaln(a) + gammaln(b) - gammaln(a + b))
),
]


betainc = BetaInc(upgrade_to_float_no_complex, name="betainc")


class BetaIncDda(TernaryScalarOp):
"""
Gradient of the regularized incomplete beta function wrt to the first argument (a)
"""

def impl(self, a, b, x):
return _betainc_derivative(a, b, x, wrtp=True)


betainc_dda_scalar = BetaIncDda(upgrade_to_float_no_complex, name="betainc_dda")


class BetaIncDdb(TernaryScalarOp):
"""
Gradient of the regularized incomplete beta function wrt to the second argument (b)
"""

def impl(self, a, b, x):
return _betainc_derivative(a, b, x, wrtp=False)


betainc_ddb_scalar = BetaIncDdb(upgrade_to_float_no_complex, name="betainc_ddb")


def _betainc_derivative(p, q, x, wrtp=True):
"""
Compute the derivative of regularized incomplete beta function wrt to p (alpha) or q (beta)
Reference: Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function.
Journal of Statistical Software, 3(1), 1-20.
"""

def _betainc_a_n(f, p, q, n):
"""
Numerator (a_n) of the nth approximant of the continued fraction
representation of the regularized incomplete beta function
"""

if n == 1:
return p * f * (q - 1) / (q * (p + 1))

p2n = p + 2 * n
F1 = p ** 2 * f ** 2 * (n - 1) / (q ** 2)
F2 = (
(p + q + n - 2)
* (p + n - 1)
* (q - n)
/ ((p2n - 3) * (p2n - 2) ** 2 * (p2n - 1))
)

return F1 * F2

def _betainc_b_n(f, p, q, n):
"""
Offset (b_n) of the nth approximant of the continued fraction
representation of the regularized incomplete beta function
"""
pf = p * f
p2n = p + 2 * n

N1 = 2 * (pf + 2 * q) * n * (n + p - 1) + p * q * (p - 2 - pf)
D1 = q * (p2n - 2) * p2n

return N1 / D1

def _betainc_da_n_dp(f, p, q, n):
"""
Derivative of a_n wrt p
"""

if n == 1:
return -p * f * (q - 1) / (q * (p + 1) ** 2)

pp = p ** 2
ppp = pp * p
p2n = p + 2 * n

N1 = -(n - 1) * f ** 2 * pp * (q - n)
N2a = (-8 + 8 * p + 8 * q) * n ** 3
N2b = (16 * pp + (-44 + 20 * q) * p + 26 - 24 * q) * n ** 2
N2c = (10 * ppp + (14 * q - 46) * pp + (-40 * q + 66) * p - 28 + 24 * q) * n
N2d = 2 * pp ** 2 + (-13 + 3 * q) * ppp + (-14 * q + 30) * pp
N2e = (-29 + 19 * q) * p + 10 - 8 * q

D1 = q ** 2 * (p2n - 3) ** 2
D2 = (p2n - 2) ** 3 * (p2n - 1) ** 2

return (N1 / D1) * (N2a + N2b + N2c + N2d + N2e) / D2

def _betainc_da_n_dq(f, p, q, n):
"""
Derivative of a_n wrt q
"""
if n == 1:
return p * f / (q * (p + 1))

p2n = p + 2 * n
F1 = (p ** 2 * f ** 2 / (q ** 2)) * (n - 1) * (p + n - 1) * (2 * q + p - 2)
D1 = (p2n - 3) * (p2n - 2) ** 2 * (p2n - 1)

return F1 / D1

def _betainc_db_n_dp(f, p, q, n):
"""
Derivative of b_n wrt p
"""
p2n = p + 2 * n
pp = p ** 2
q4 = 4 * q
p4 = 4 * p

F1 = (p * f / q) * (
(-p4 - q4 + 4) * n ** 2 + (p4 - 4 + q4 - 2 * pp) * n + pp * q
)
D1 = (p2n - 2) ** 2 * p2n ** 2

return F1 / D1

def _betainc_db_n_dq(f, p, q, n):
"""
Derivative of b_n wrt to q
"""
p2n = p + 2 * n
return -(p ** 2 * f) / (q * (p2n - 2) * p2n)

# Input validation
if not (0 <= x <= 1) or p < 0 or q < 0:
return np.nan

if x > (p / (p + q)):
return -_betainc_derivative(q, p, 1 - x, not wrtp)

min_iters = 3
max_iters = 200
err_threshold = 1e-12

derivative_old = 0

Am2, Am1 = 1, 1
Bm2, Bm1 = 0, 1
dAm2, dAm1 = 0, 0
dBm2, dBm1 = 0, 0

f = (q * x) / (p * (1 - x))
K = np.exp(
p * np.log(x) + (q - 1) * np.log1p(-x) - np.log(p) - scipy.special.betaln(p, q)
)
if wrtp:
dK = np.log(x) - 1 / p + scipy.special.digamma(p + q) - scipy.special.digamma(p)
else:
dK = np.log1p(-x) + scipy.special.digamma(p + q) - scipy.special.digamma(q)

for n in range(1, max_iters + 1):
a_n_ = _betainc_a_n(f, p, q, n)
b_n_ = _betainc_b_n(f, p, q, n)
if wrtp:
da_n = _betainc_da_n_dp(f, p, q, n)
db_n = _betainc_db_n_dp(f, p, q, n)
else:
da_n = _betainc_da_n_dq(f, p, q, n)
db_n = _betainc_db_n_dq(f, p, q, n)

A = a_n_ * Am2 + b_n_ * Am1
B = a_n_ * Bm2 + b_n_ * Bm1
dA = da_n * Am2 + a_n_ * dAm2 + db_n * Am1 + b_n_ * dAm1
dB = da_n * Bm2 + a_n_ * dBm2 + db_n * Bm1 + b_n_ * dBm1

Am2, Am1 = Am1, A
Bm2, Bm1 = Bm1, B
dAm2, dAm1 = dAm1, dA
dBm2, dBm1 = dBm1, dB

if n < min_iters - 1:
continue

F1 = A / B
F2 = (dA - F1 * dB) / B
derivative = K * (F1 * dK + F2)

errapx = abs(derivative_old - derivative)
d_errapx = errapx / max(err_threshold, abs(derivative))
derivative_old = derivative

if d_errapx <= err_threshold:
break

if n >= max_iters:
return np.nan

return derivative
5 changes: 5 additions & 0 deletions aesara/tensor/inplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,11 @@ def softplus_inplace(x):
"""Compute log(1 + exp(x)), also known as softplus or log1pexp"""


@scalar_elemwise
def betainc_inplace(a, b, x):
"""Regularized incomplete beta function"""


@scalar_elemwise
def second_inplace(a):
"""Fill `a` with `b`"""
Expand Down
6 changes: 6 additions & 0 deletions aesara/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1417,6 +1417,11 @@ def softplus(x):
"""Compute log(1 + exp(x)), also known as softplus or log1pexp"""


@scalar_elemwise
def betainc(a, b, x):
"""Regularized incomplete beta function"""


@scalar_elemwise
def real(z):
"""Return real component of complex-valued tensor `z`"""
Expand Down Expand Up @@ -2847,6 +2852,7 @@ def power(x, y):
"sigmoid",
"expit",
"softplus",
"betainc",
"real",
"imag",
"angle",
Expand Down
78 changes: 77 additions & 1 deletion tests/scalar/test_math.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import numpy as np
from numpy.testing import assert_allclose, assert_almost_equal

import aesara.tensor as aet
from aesara import function
from aesara.graph.fg import FunctionGraph
from aesara.link.c.basic import CLinker
from aesara.scalar.math import gammainc, gammaincc, gammal, gammau
from aesara.scalar.math import betainc, gammainc, gammaincc, gammal, gammau


def test_gammainc_nan():
Expand Down Expand Up @@ -44,3 +46,77 @@ def test_gammau_nan():
assert np.isnan(test_func(-1, 1))
assert np.isnan(test_func(1, -1))
assert np.isnan(test_func(-1, -1))


class TestBetaIncGrad:

# This test replicates the following STAN test:
# https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/grad_reg_inc_beta_test.cpp
def test_stan_grad_combined(self):
a, b, z = aet.scalars("a", "b", "z")
betainc_out = betainc(a, b, z)
betainc_grad = aet.grad(betainc_out, [a, b], null_gradients="return")
f_grad = function([a, b, z], betainc_grad)

for test_a, test_b, test_z, expected_dda, expected_ddb in (
(1.0, 1.0, 1.0, 0, np.nan),
(1.0, 1.0, 0.4, -0.36651629, 0.30649537),
):
assert_allclose(
f_grad(test_a, test_b, test_z), [expected_dda, expected_ddb]
)

# This test combines the following STAN tests:
# https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/inc_beta_dda_test.cpp
# https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/inc_beta_ddb_test.cpp
# https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/inc_beta_ddz_test.cpp
def test_stan_grad_partial(self):
a, b, z = aet.scalars("a", "b", "z")
betainc_out = betainc(a, b, z)
betainc_grad = aet.grad(betainc_out, [a, b, z])
f_grad = function([a, b, z], betainc_grad)

for test_a, test_b, test_z, expected_dda, expected_ddb, expected_ddz in (
(1.5, 1.25, 0.001, -0.00028665637, 4.41357328e-05, 0.063300692),
(1.5, 1.25, 0.5, -0.26038693947, 0.29301795, 1.1905416),
(1.5, 1.25, 0.6, -0.23806757, 0.32279575, 1.23341068),
(1.5, 1.25, 0.999, -0.00022264493, 0.0018969609, 0.35587692),
(15000, 1.25, 0.001, 0, 0, 0),
(15000, 1.25, 0.5, 0, 0, 0),
(15000, 1.25, 0.6, 0, 0, 0),
(15000, 1.25, 0.999, -6.59543226e-10, 2.00849793e-06, 0.009898182),
(1.5, 12500, 0.001, -3.93756641e-05, 1.47821755e-09, 0.1848717),
(1.5, 12500, 0.5, 0, 0, 0),
(1.5, 12500, 0.6, 0, 0, 0),
(1.5, 12500, 0.999, 0, 0, 0),
(15000, 12500, 0.001, 0, 0, 0),
(15000, 12500, 0.5, -8.72102443e-53, 9.55282792e-53, 5.01131256e-48),
(15000, 12500, 0.6, -4.085621e-14, -5.5067062e-14, 1.15135267e-71),
(15000, 12500, 0.999, 0, 0, 0),
):

assert_almost_equal(
f_grad(test_a, test_b, test_z),
[expected_dda, expected_ddb, expected_ddz],
decimal=4,
)

# This test compares against the tabulated values in:
# Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function.
# Journal of Statistical Software, 3(1), 1-20.
def test_boik_robison_cox(self):
a, b, z = aet.scalars("a", "b", "z")
betainc_out = betainc(a, b, z)
betainc_grad = aet.grad(betainc_out, [a, b])
f_grad = function([a, b, z], betainc_grad)

for test_a, test_b, test_z, expected_dda, expected_ddb in (
(1.5, 11.0, 0.001, -4.5720356e-03, 1.1845673e-04),
(1.5, 11.0, 0.5, -2.5501997e-03, 9.0824388e-04),
(1000.0, 1000.0, 0.5, -8.9224793e-03, 8.9224793e-03),
(1000.0, 1000.0, 0.55, -3.6713108e-07, 4.0584118e-07),
):
assert_almost_equal(
f_grad(test_a, test_b, test_z),
[expected_dda, expected_ddb],
)
Loading

0 comments on commit 91d36ac

Please sign in to comment.