From 5388d61ec215ad3e6eb16d25c5da616f9f4efc07 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Sun, 6 Jun 2021 14:10:34 +0200 Subject: [PATCH] Implement betainc and derivatives --- aesara/scalar/math.py | 222 ++++++++++++++++++++++++++++++++ aesara/tensor/inplace.py | 5 + aesara/tensor/math.py | 6 + tests/scalar/test_math.py | 23 +++- tests/tensor/test_math_scipy.py | 99 ++++++++++++++ 5 files changed, 354 insertions(+), 1 deletion(-) diff --git a/aesara/scalar/math.py b/aesara/scalar/math.py index 1c06681dc7..b8355b62f0 100644 --- a/aesara/scalar/math.py +++ b/aesara/scalar/math.py @@ -5,6 +5,7 @@ """ import os +import warnings import numpy as np import scipy.special @@ -14,12 +15,15 @@ from aesara.gradient import grad_not_implemented from aesara.scalar.basic import ( BinaryScalarOp, + ScalarOp, UnaryScalarOp, complex_types, discrete_types, exp, float64, float_types, + log, + log1p, true_div, upcast, upgrade_to_float, @@ -1044,3 +1048,221 @@ def c_code(self, node, name, inp, out, sub): log1mexp = Log1mexp(upgrade_to_float, name="scalar_log1mexp") + + +class BetaInc(ScalarOp): + """ + Regularized incomplete beta function + """ + + nin = 3 + nfunc_spec = ("scipy.special.betainc", 3, 1) + + def impl(self, a, b, x): + return scipy.special.betainc(a, b, x) + + def grad(self, inp, grads): + a, b, x = inp + (gz,) = grads + + return [ + gz * betainc_der(a, b, x, True), + gz * betainc_der(a, b, x, False), + gz + * exp( + log1p(-x) * (b - 1) + + log(x) * (a - 1) + - (gammaln(a) + gammaln(b) - gammaln(a + b)) + ), + ] + + +betainc = BetaInc(upgrade_to_float_no_complex, name="betainc") + + +class BetaIncDer(ScalarOp): + """ + Gradient of the regularized incomplete beta function wrt to the first + argument (alpha) or the second argument (bbeta), depending on whether the + fourth argument to betainc_der is `True` or `False`, respectively. + + Reference: Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function. + Journal of Statistical Software, 3(1), 1-20. + """ + + nin = 4 + + def impl(self, p, q, x, wrtp): + def _betainc_a_n(f, p, q, n): + """ + Numerator (a_n) of the nth approximant of the continued fraction + representation of the regularized incomplete beta function + """ + + if n == 1: + return p * f * (q - 1) / (q * (p + 1)) + + p2n = p + 2 * n + F1 = p ** 2 * f ** 2 * (n - 1) / (q ** 2) + F2 = ( + (p + q + n - 2) + * (p + n - 1) + * (q - n) + / ((p2n - 3) * (p2n - 2) ** 2 * (p2n - 1)) + ) + + return F1 * F2 + + def _betainc_b_n(f, p, q, n): + """ + Offset (b_n) of the nth approximant of the continued fraction + representation of the regularized incomplete beta function + """ + pf = p * f + p2n = p + 2 * n + + N1 = 2 * (pf + 2 * q) * n * (n + p - 1) + p * q * (p - 2 - pf) + D1 = q * (p2n - 2) * p2n + + return N1 / D1 + + def _betainc_da_n_dp(f, p, q, n): + """ + Derivative of a_n wrt p + """ + + if n == 1: + return -p * f * (q - 1) / (q * (p + 1) ** 2) + + pp = p ** 2 + ppp = pp * p + p2n = p + 2 * n + + N1 = -(n - 1) * f ** 2 * pp * (q - n) + N2a = (-8 + 8 * p + 8 * q) * n ** 3 + N2b = (16 * pp + (-44 + 20 * q) * p + 26 - 24 * q) * n ** 2 + N2c = (10 * ppp + (14 * q - 46) * pp + (-40 * q + 66) * p - 28 + 24 * q) * n + N2d = 2 * pp ** 2 + (-13 + 3 * q) * ppp + (-14 * q + 30) * pp + N2e = (-29 + 19 * q) * p + 10 - 8 * q + + D1 = q ** 2 * (p2n - 3) ** 2 + D2 = (p2n - 2) ** 3 * (p2n - 1) ** 2 + + return (N1 / D1) * (N2a + N2b + N2c + N2d + N2e) / D2 + + def _betainc_da_n_dq(f, p, q, n): + """ + Derivative of a_n wrt q + """ + if n == 1: + return p * f / (q * (p + 1)) + + p2n = p + 2 * n + F1 = (p ** 2 * f ** 2 / (q ** 2)) * (n - 1) * (p + n - 1) * (2 * q + p - 2) + D1 = (p2n - 3) * (p2n - 2) ** 2 * (p2n - 1) + + return F1 / D1 + + def _betainc_db_n_dp(f, p, q, n): + """ + Derivative of b_n wrt p + """ + p2n = p + 2 * n + pp = p ** 2 + q4 = 4 * q + p4 = 4 * p + + F1 = (p * f / q) * ( + (-p4 - q4 + 4) * n ** 2 + (p4 - 4 + q4 - 2 * pp) * n + pp * q + ) + D1 = (p2n - 2) ** 2 * p2n ** 2 + + return F1 / D1 + + def _betainc_db_n_dq(f, p, q, n): + """ + Derivative of b_n wrt to q + """ + p2n = p + 2 * n + return -(p ** 2 * f) / (q * (p2n - 2) * p2n) + + # Input validation + if not (0 <= x <= 1) or p < 0 or q < 0: + return np.nan + + if x > (p / (p + q)): + return -self.impl(q, p, 1 - x, not wrtp) + + min_iters = 3 + max_iters = 200 + err_threshold = 1e-12 + + derivative_old = 0 + + Am2, Am1 = 1, 1 + Bm2, Bm1 = 0, 1 + dAm2, dAm1 = 0, 0 + dBm2, dBm1 = 0, 0 + + f = (q * x) / (p * (1 - x)) + K = np.exp( + p * np.log(x) + + (q - 1) * np.log1p(-x) + - np.log(p) + - scipy.special.betaln(p, q) + ) + if wrtp: + dK = ( + np.log(x) + - 1 / p + + scipy.special.digamma(p + q) + - scipy.special.digamma(p) + ) + else: + dK = np.log1p(-x) + scipy.special.digamma(p + q) - scipy.special.digamma(q) + + for n in range(1, max_iters + 1): + a_n_ = _betainc_a_n(f, p, q, n) + b_n_ = _betainc_b_n(f, p, q, n) + if wrtp: + da_n = _betainc_da_n_dp(f, p, q, n) + db_n = _betainc_db_n_dp(f, p, q, n) + else: + da_n = _betainc_da_n_dq(f, p, q, n) + db_n = _betainc_db_n_dq(f, p, q, n) + + A = a_n_ * Am2 + b_n_ * Am1 + B = a_n_ * Bm2 + b_n_ * Bm1 + dA = da_n * Am2 + a_n_ * dAm2 + db_n * Am1 + b_n_ * dAm1 + dB = da_n * Bm2 + a_n_ * dBm2 + db_n * Bm1 + b_n_ * dBm1 + + Am2, Am1 = Am1, A + Bm2, Bm1 = Bm1, B + dAm2, dAm1 = dAm1, dA + dBm2, dBm1 = dBm1, dB + + if n < min_iters - 1: + continue + + F1 = A / B + F2 = (dA - F1 * dB) / B + derivative = K * (F1 * dK + F2) + + errapx = abs(derivative_old - derivative) + d_errapx = errapx / max(err_threshold, abs(derivative)) + derivative_old = derivative + + if d_errapx <= err_threshold: + break + + if n >= max_iters: + warnings.warn( + f"_betainc_derivative did not converge after {n} iterations", + RuntimeWarning, + ) + return np.nan + + return derivative + + +betainc_der = BetaIncDer(upgrade_to_float_no_complex, name="betainc_der") diff --git a/aesara/tensor/inplace.py b/aesara/tensor/inplace.py index e2814f6e9f..eb0da2a55e 100644 --- a/aesara/tensor/inplace.py +++ b/aesara/tensor/inplace.py @@ -323,6 +323,11 @@ def log1mexp_inplace(x): """Compute log(1 - exp(x)), also known as log1mexp""" +@scalar_elemwise +def betainc_inplace(a, b, x): + """Regularized incomplete beta function""" + + @scalar_elemwise def second_inplace(a): """Fill `a` with `b`""" diff --git a/aesara/tensor/math.py b/aesara/tensor/math.py index a621801075..f74986ca13 100644 --- a/aesara/tensor/math.py +++ b/aesara/tensor/math.py @@ -1429,6 +1429,11 @@ def log1mexp(x): """Compute log(1 - exp(x)), also known as log1mexp""" +@scalar_elemwise +def betainc(a, b, x): + """Regularized incomplete beta function""" + + @scalar_elemwise def real(z): """Return real component of complex-valued tensor `z`""" @@ -2909,6 +2914,7 @@ def logsumexp(x, axis=None, keepdims=False): "softplus", "log1pexp", "log1mexp", + "betainc", "real", "imag", "angle", diff --git a/tests/scalar/test_math.py b/tests/scalar/test_math.py index 08bd51fdba..38bd198677 100644 --- a/tests/scalar/test_math.py +++ b/tests/scalar/test_math.py @@ -1,9 +1,12 @@ import numpy as np +import scipy.special as sp import aesara.tensor as aet +from aesara import function +from aesara.compile.mode import Mode from aesara.graph.fg import FunctionGraph from aesara.link.c.basic import CLinker -from aesara.scalar.math import gammainc, gammaincc, gammal, gammau +from aesara.scalar.math import betainc, betainc_der, gammainc, gammaincc, gammal, gammau def test_gammainc_nan(): @@ -44,3 +47,21 @@ def test_gammau_nan(): assert np.isnan(test_func(-1, 1)) assert np.isnan(test_func(1, -1)) assert np.isnan(test_func(-1, -1)) + + +def test_betainc(): + a, b, x = aet.scalars("a", "b", "x") + res = betainc(a, b, x) + test_func = function([a, b, x], res, mode=Mode("py")) + assert np.isclose(test_func(15, 10, 0.7), sp.betainc(15, 10, 0.7)) + + +def test_betainc_derivative_nan(): + a, b, x = aet.scalars("a", "b", "x") + res = betainc_der(a, b, x, True) + test_func = function([a, b, x], res, mode=Mode("py")) + assert not np.isnan(test_func(1, 1, 1)) + assert np.isnan(test_func(1, 1, -1)) + assert np.isnan(test_func(1, 1, 2)) + assert np.isnan(test_func(1, -1, 1)) + assert np.isnan(test_func(1, 1, -1)) diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py index b5e1052142..96b1c469e5 100644 --- a/tests/tensor/test_math_scipy.py +++ b/tests/tensor/test_math_scipy.py @@ -9,6 +9,7 @@ import scipy.special import scipy.stats +from aesara import function from aesara import tensor as aet from aesara.compile.mode import get_default_mode from aesara.configdefaults import config @@ -603,3 +604,101 @@ def expected_log1mexp(x): def test_deprecated_module(): with pytest.warns(DeprecationWarning): import aesara.scalar.basic_scipy # noqa: F401 + + +_good_broadcast_ternary_betainc = dict( + normal=( + random_ranged(0, 1000, (2, 3)), + random_ranged(0, 1000, (2, 3)), + random_ranged(0, 1, (2, 3)), + ), +) + +TestBetaincBroadcast = makeBroadcastTester( + op=aet.betainc, + expected=scipy.special.betainc, + good=_good_broadcast_ternary_betainc, + grad=_good_broadcast_ternary_betainc, +) + +TestBetaincInplaceBroadcast = makeBroadcastTester( + op=inplace.betainc_inplace, + expected=scipy.special.betainc, + good=_good_broadcast_ternary_betainc, + grad=_good_broadcast_ternary_betainc, + inplace=True, +) + + +class TestBetaIncGrad: + def test_stan_grad_partial(self): + # This test combines the following STAN tests: + # https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/inc_beta_dda_test.cpp + # https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/inc_beta_ddb_test.cpp + # https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/inc_beta_ddz_test.cpp + a, b, z = aet.scalars("a", "b", "z") + betainc_out = aet.betainc(a, b, z) + betainc_grad = aet.grad(betainc_out, [a, b, z]) + f_grad = function([a, b, z], betainc_grad) + + decimal_precision = 7 if config.floatX == "float64" else 3 + + for test_a, test_b, test_z, expected_dda, expected_ddb, expected_ddz in ( + (1.5, 1.25, 0.001, -0.00028665637, 4.41357328e-05, 0.063300692), + (1.5, 1.25, 0.5, -0.26038693947, 0.29301795, 1.1905416), + (1.5, 1.25, 0.6, -0.23806757, 0.32279575, 1.23341068), + (1.5, 1.25, 0.999, -0.00022264493, 0.0018969609, 0.35587692), + (15000, 1.25, 0.001, 0, 0, 0), + (15000, 1.25, 0.5, 0, 0, 0), + (15000, 1.25, 0.6, 0, 0, 0), + (15000, 1.25, 0.999, -6.59543226e-10, 2.00849793e-06, 0.009898182), + (1.5, 12500, 0.001, -3.93756641e-05, 1.47821755e-09, 0.1848717), + (1.5, 12500, 0.5, 0, 0, 0), + (1.5, 12500, 0.6, 0, 0, 0), + (1.5, 12500, 0.999, 0, 0, 0), + (15000, 12500, 0.001, 0, 0, 0), + (15000, 12500, 0.5, -8.72102443e-53, 9.55282792e-53, 5.01131256e-48), + (15000, 12500, 0.6, -4.085621e-14, -5.5067062e-14, 1.15135267e-71), + (15000, 12500, 0.999, 0, 0, 0), + ): + np.testing.assert_almost_equal( + f_grad(test_a, test_b, test_z), + [expected_dda, expected_ddb, expected_ddz], + decimal=decimal_precision, + ) + + def test_boik_robison_cox(self): + # This test compares against the tabulated values in: + # Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function. + # Journal of Statistical Software, 3(1), 1-20. + a, b, z = aet.scalars("a", "b", "z") + betainc_out = aet.betainc(a, b, z) + betainc_grad = aet.grad(betainc_out, [a, b]) + f_grad = function([a, b, z], betainc_grad) + + for test_a, test_b, test_z, expected_dda, expected_ddb in ( + (1.5, 11.0, 0.001, -4.5720356e-03, 1.1845673e-04), + (1.5, 11.0, 0.5, -2.5501997e-03, 9.0824388e-04), + (1000.0, 1000.0, 0.5, -8.9224793e-03, 8.9224793e-03), + (1000.0, 1000.0, 0.55, -3.6713108e-07, 4.0584118e-07), + ): + np.testing.assert_almost_equal( + f_grad(test_a, test_b, test_z), + [expected_dda, expected_ddb], + ) + + def test_beta_inc_stan_grad_combined(self): + # This test replicates the following STAN test: + # https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/grad_reg_inc_beta_test.cpp + a, b, z = aet.scalars("a", "b", "z") + betainc_out = aet.betainc(a, b, z) + betainc_grad = aet.grad(betainc_out, [a, b]) + f_grad = function([a, b, z], betainc_grad) + + for test_a, test_b, test_z, expected_dda, expected_ddb in ( + (1.0, 1.0, 1.0, 0, np.nan), + (1.0, 1.0, 0.4, -0.36651629, 0.30649537), + ): + np.testing.assert_allclose( + f_grad(test_a, test_b, test_z), [expected_dda, expected_ddb] + )