Skip to content

Commit

Permalink
Implement betainc and derivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo committed Jul 5, 2021
1 parent b5313f1 commit cca785f
Show file tree
Hide file tree
Showing 5 changed files with 354 additions and 1 deletion.
222 changes: 222 additions & 0 deletions aesara/scalar/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import os
import warnings

import numpy as np
import scipy.special
Expand All @@ -14,12 +15,15 @@
from aesara.gradient import grad_not_implemented
from aesara.scalar.basic import (
BinaryScalarOp,
ScalarOp,
UnaryScalarOp,
complex_types,
discrete_types,
exp,
float64,
float_types,
log,
log1p,
true_div,
upcast,
upgrade_to_float,
Expand Down Expand Up @@ -1044,3 +1048,221 @@ def c_code(self, node, name, inp, out, sub):


log1mexp = Log1mexp(upgrade_to_float, name="scalar_log1mexp")


class BetaInc(ScalarOp):
"""
Regularized incomplete beta function
"""

nin = 3
nfunc_spec = ("scipy.special.betainc", 3, 1)

def impl(self, a, b, x):
return scipy.special.betainc(a, b, x)

def grad(self, inp, grads):
a, b, x = inp
(gz,) = grads

return [
gz * betainc_der(a, b, x, True),
gz * betainc_der(a, b, x, False),
gz
* exp(
log1p(-x) * (b - 1)
+ log(x) * (a - 1)
- (gammaln(a) + gammaln(b) - gammaln(a + b))
),
]


betainc = BetaInc(upgrade_to_float_no_complex, name="betainc")


class BetaIncDer(ScalarOp):
"""
Gradient of the regularized incomplete beta function wrt to the first
argument (alpha) or the second argument (bbeta), depending on whether the
fourth argument to betainc_der is `True` or `False`, respectively.
"""

nin = 4

def impl(self, a, b, x, wrta):
return _betainc_derivative(a, b, x, wrta)


betainc_der = BetaIncDer(upgrade_to_float_no_complex, name="betainc_der")


def _betainc_derivative(p, q, x, wrtp):
"""
Compute the derivative of regularized incomplete beta function wrt to p (alpha) or q (beta)
Reference: Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function.
Journal of Statistical Software, 3(1), 1-20.
"""

def _betainc_a_n(f, p, q, n):
"""
Numerator (a_n) of the nth approximant of the continued fraction
representation of the regularized incomplete beta function
"""

if n == 1:
return p * f * (q - 1) / (q * (p + 1))

p2n = p + 2 * n
F1 = p ** 2 * f ** 2 * (n - 1) / (q ** 2)
F2 = (
(p + q + n - 2)
* (p + n - 1)
* (q - n)
/ ((p2n - 3) * (p2n - 2) ** 2 * (p2n - 1))
)

return F1 * F2

def _betainc_b_n(f, p, q, n):
"""
Offset (b_n) of the nth approximant of the continued fraction
representation of the regularized incomplete beta function
"""
pf = p * f
p2n = p + 2 * n

N1 = 2 * (pf + 2 * q) * n * (n + p - 1) + p * q * (p - 2 - pf)
D1 = q * (p2n - 2) * p2n

return N1 / D1

def _betainc_da_n_dp(f, p, q, n):
"""
Derivative of a_n wrt p
"""

if n == 1:
return -p * f * (q - 1) / (q * (p + 1) ** 2)

pp = p ** 2
ppp = pp * p
p2n = p + 2 * n

N1 = -(n - 1) * f ** 2 * pp * (q - n)
N2a = (-8 + 8 * p + 8 * q) * n ** 3
N2b = (16 * pp + (-44 + 20 * q) * p + 26 - 24 * q) * n ** 2
N2c = (10 * ppp + (14 * q - 46) * pp + (-40 * q + 66) * p - 28 + 24 * q) * n
N2d = 2 * pp ** 2 + (-13 + 3 * q) * ppp + (-14 * q + 30) * pp
N2e = (-29 + 19 * q) * p + 10 - 8 * q

D1 = q ** 2 * (p2n - 3) ** 2
D2 = (p2n - 2) ** 3 * (p2n - 1) ** 2

return (N1 / D1) * (N2a + N2b + N2c + N2d + N2e) / D2

def _betainc_da_n_dq(f, p, q, n):
"""
Derivative of a_n wrt q
"""
if n == 1:
return p * f / (q * (p + 1))

p2n = p + 2 * n
F1 = (p ** 2 * f ** 2 / (q ** 2)) * (n - 1) * (p + n - 1) * (2 * q + p - 2)
D1 = (p2n - 3) * (p2n - 2) ** 2 * (p2n - 1)

return F1 / D1

def _betainc_db_n_dp(f, p, q, n):
"""
Derivative of b_n wrt p
"""
p2n = p + 2 * n
pp = p ** 2
q4 = 4 * q
p4 = 4 * p

F1 = (p * f / q) * (
(-p4 - q4 + 4) * n ** 2 + (p4 - 4 + q4 - 2 * pp) * n + pp * q
)
D1 = (p2n - 2) ** 2 * p2n ** 2

return F1 / D1

def _betainc_db_n_dq(f, p, q, n):
"""
Derivative of b_n wrt to q
"""
p2n = p + 2 * n
return -(p ** 2 * f) / (q * (p2n - 2) * p2n)

# Input validation
if not (0 <= x <= 1) or p < 0 or q < 0:
return np.nan

if x > (p / (p + q)):
return -_betainc_derivative(q, p, 1 - x, not wrtp)

min_iters = 3
max_iters = 200
err_threshold = 1e-12

derivative_old = 0

Am2, Am1 = 1, 1
Bm2, Bm1 = 0, 1
dAm2, dAm1 = 0, 0
dBm2, dBm1 = 0, 0

f = (q * x) / (p * (1 - x))
K = np.exp(
p * np.log(x) + (q - 1) * np.log1p(-x) - np.log(p) - scipy.special.betaln(p, q)
)
if wrtp:
dK = np.log(x) - 1 / p + scipy.special.digamma(p + q) - scipy.special.digamma(p)
else:
dK = np.log1p(-x) + scipy.special.digamma(p + q) - scipy.special.digamma(q)

for n in range(1, max_iters + 1):
a_n_ = _betainc_a_n(f, p, q, n)
b_n_ = _betainc_b_n(f, p, q, n)
if wrtp:
da_n = _betainc_da_n_dp(f, p, q, n)
db_n = _betainc_db_n_dp(f, p, q, n)
else:
da_n = _betainc_da_n_dq(f, p, q, n)
db_n = _betainc_db_n_dq(f, p, q, n)

A = a_n_ * Am2 + b_n_ * Am1
B = a_n_ * Bm2 + b_n_ * Bm1
dA = da_n * Am2 + a_n_ * dAm2 + db_n * Am1 + b_n_ * dAm1
dB = da_n * Bm2 + a_n_ * dBm2 + db_n * Bm1 + b_n_ * dBm1

Am2, Am1 = Am1, A
Bm2, Bm1 = Bm1, B
dAm2, dAm1 = dAm1, dA
dBm2, dBm1 = dBm1, dB

if n < min_iters - 1:
continue

F1 = A / B
F2 = (dA - F1 * dB) / B
derivative = K * (F1 * dK + F2)

errapx = abs(derivative_old - derivative)
d_errapx = errapx / max(err_threshold, abs(derivative))
derivative_old = derivative

if d_errapx <= err_threshold:
break

if n >= max_iters:
warnings.warn(
f"_betainc_derivative did not converge after {n} iterations",
RuntimeWarning,
)
return np.nan

return derivative
5 changes: 5 additions & 0 deletions aesara/tensor/inplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,11 @@ def log1mexp_inplace(x):
"""Compute log(1 - exp(x)), also known as log1mexp"""


@scalar_elemwise
def betainc_inplace(a, b, x):
"""Regularized incomplete beta function"""


@scalar_elemwise
def second_inplace(a):
"""Fill `a` with `b`"""
Expand Down
6 changes: 6 additions & 0 deletions aesara/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1429,6 +1429,11 @@ def log1mexp(x):
"""Compute log(1 - exp(x)), also known as log1mexp"""


@scalar_elemwise
def betainc(a, b, x):
"""Regularized incomplete beta function"""


@scalar_elemwise
def real(z):
"""Return real component of complex-valued tensor `z`"""
Expand Down Expand Up @@ -2909,6 +2914,7 @@ def logsumexp(x, axis=None, keepdims=False):
"softplus",
"log1pexp",
"log1mexp",
"betainc",
"real",
"imag",
"angle",
Expand Down
23 changes: 22 additions & 1 deletion tests/scalar/test_math.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import numpy as np
import scipy.special as sp

import aesara.tensor as aet
from aesara import function
from aesara.compile.mode import Mode
from aesara.graph.fg import FunctionGraph
from aesara.link.c.basic import CLinker
from aesara.scalar.math import gammainc, gammaincc, gammal, gammau
from aesara.scalar.math import betainc, betainc_der, gammainc, gammaincc, gammal, gammau


def test_gammainc_nan():
Expand Down Expand Up @@ -44,3 +47,21 @@ def test_gammau_nan():
assert np.isnan(test_func(-1, 1))
assert np.isnan(test_func(1, -1))
assert np.isnan(test_func(-1, -1))


def test_betainc():
a, b, x = aet.scalars("a", "b", "x")
res = betainc(a, b, x)
test_func = function([a, b, x], res, mode=Mode("py"))
assert np.isclose(test_func(15, 10, 0.7), sp.betainc(15, 10, 0.7))


def test_betainc_derivative_nan():
a, b, x = aet.scalars("a", "b", "x")
res = betainc_der(a, b, x, True)
test_func = function([a, b, x], res, mode=Mode("py"))
assert not np.isnan(test_func(1, 1, 1))
assert np.isnan(test_func(1, 1, -1))
assert np.isnan(test_func(1, 1, 2))
assert np.isnan(test_func(1, -1, 1))
assert np.isnan(test_func(1, 1, -1))
Loading

0 comments on commit cca785f

Please sign in to comment.