From 91d36ac6d5129af6b68ad55bb1bb68daed80105c Mon Sep 17 00:00:00 2001
From: Ricardo
Date: Sun, 6 Jun 2021 14:10:34 +0200
Subject: [PATCH] Implement betainc and derivatives

---
 aesara/scalar/basic.py          |   4 +
 aesara/scalar/math.py           | 253 ++++++++++++++++++++++++++++++++
 aesara/tensor/inplace.py        |   5 +
 aesara/tensor/math.py           |   6 +
 tests/scalar/test_math.py       |  78 +++++++++-
 tests/tensor/test_math_scipy.py |  24 +++
 6 files changed, 369 insertions(+), 1 deletion(-)

diff --git a/aesara/scalar/basic.py b/aesara/scalar/basic.py
index 206e62f000..de0dce3ed0 100644
--- a/aesara/scalar/basic.py
+++ b/aesara/scalar/basic.py
@@ -1276,6 +1276,10 @@ class BinaryScalarOp(ScalarOp):
     nin = 2
 
 
+class TernaryScalarOp(ScalarOp):
+    nin = 3
+
+
 class LogicalComparison(BinaryScalarOp):
     def __init__(self, *args, **kwargs):
         BinaryScalarOp.__init__(self, *args, **kwargs)
diff --git a/aesara/scalar/math.py b/aesara/scalar/math.py
index 2a423fbbe9..9cda67ee00 100644
--- a/aesara/scalar/math.py
+++ b/aesara/scalar/math.py
@@ -12,12 +12,15 @@
 from aesara.gradient import grad_not_implemented
 from aesara.scalar.basic import (
     BinaryScalarOp,
+    TernaryScalarOp,
     UnaryScalarOp,
     complex_types,
     discrete_types,
     exp,
     float64,
     float_types,
+    log,
+    log1p,
     upcast,
     upgrade_to_float,
     upgrade_to_float64,
@@ -1070,3 +1073,253 @@ def c_code_cache_version(self):
 
 
 softplus = Softplus(upgrade_to_float, name="scalar_softplus")
+
+
+class BetaInc(TernaryScalarOp):
+    """
+    Regularized incomplete beta function
+
+    The Python implementation evaluates ``scipy.special.betainc(a, b, x)``.
+    """
+
+    nfunc_spec = ("scipy.special.betainc", 3, 1)
+
+    def impl(self, a, b, x):
+        return scipy.special.betainc(a, b, x)
+
+    def grad(self, inp, grads):
+        a, b, x = inp
+        (gz,) = grads
+
+        # The partials wrt `a` and `b` are delegated to the `BetaIncDda` and
+        # `BetaIncDdb` Ops; the partial wrt `x` is assembled in log-space from
+        # `log1p`, `log` and `gammaln` terms before being exponentiated.
+        return [
+            gz * betainc_dda_scalar(a, b, x),
+            gz * betainc_ddb_scalar(a, b, x),
+            gz
+            * exp(
+                log1p(-x) * (b - 1)
+                + log(x) * (a - 1)
+                - (gammaln(a) + gammaln(b) - gammaln(a + b))
+            ),
+        ]
+
+
+betainc = BetaInc(upgrade_to_float_no_complex, name="betainc")
+
+
+class BetaIncDda(TernaryScalarOp):
+    """
+    Gradient of the regularized incomplete beta function wrt the first argument (a)
+    """
+
+    def impl(self, a, b, x):
+        return _betainc_derivative(a, b, x, wrtp=True)
+
+
+betainc_dda_scalar = BetaIncDda(upgrade_to_float_no_complex, name="betainc_dda")
+
+
+class BetaIncDdb(TernaryScalarOp):
+    """
+    Gradient of the regularized incomplete beta function wrt the second argument (b)
+    """
+
+    def impl(self, a, b, x):
+        return _betainc_derivative(a, b, x, wrtp=False)
+
+
+betainc_ddb_scalar = BetaIncDdb(upgrade_to_float_no_complex, name="betainc_ddb")
+
+
+def _betainc_derivative(p, q, x, wrtp=True):
+    """
+    Compute the derivative of regularized incomplete beta function wrt p (alpha) or q (beta)
+
+    The value is built iteratively from successive approximants of a continued
+    fraction; iteration stops once two consecutive estimates agree to within
+    a relative tolerance of 1e-12.
+
+    Parameters
+    ----------
+    p, q
+        Shape parameters (passed as ``a`` and ``b`` by the calling Ops).
+    x
+        Point at which the derivative is evaluated.
+    wrtp
+        If true differentiate wrt ``p``, otherwise wrt ``q``.
+
+    Returns
+    -------
+    The derivative, or ``np.nan`` when ``x`` lies outside ``[0, 1]``, when
+    ``p`` or ``q`` is negative, or when the tolerance is not met within 200
+    iterations.
+
+    Reference: Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function.
+    Journal of Statistical Software, 3(1), 1-20.
+    """
+
+    def _betainc_a_n(f, p, q, n):
+        """
+        Numerator (a_n) of the nth approximant of the continued fraction
+        representation of the regularized incomplete beta function
+        """
+
+        if n == 1:
+            return p * f * (q - 1) / (q * (p + 1))
+
+        p2n = p + 2 * n
+        F1 = p ** 2 * f ** 2 * (n - 1) / (q ** 2)
+        F2 = (
+            (p + q + n - 2)
+            * (p + n - 1)
+            * (q - n)
+            / ((p2n - 3) * (p2n - 2) ** 2 * (p2n - 1))
+        )
+
+        return F1 * F2
+
+    def _betainc_b_n(f, p, q, n):
+        """
+        Offset (b_n) of the nth approximant of the continued fraction
+        representation of the regularized incomplete beta function
+        """
+        pf = p * f
+        p2n = p + 2 * n
+
+        N1 = 2 * (pf + 2 * q) * n * (n + p - 1) + p * q * (p - 2 - pf)
+        D1 = q * (p2n - 2) * p2n
+
+        return N1 / D1
+
+    def _betainc_da_n_dp(f, p, q, n):
+        """
+        Derivative of a_n wrt p
+        """
+
+        if n == 1:
+            return -p * f * (q - 1) / (q * (p + 1) ** 2)
+
+        pp = p ** 2
+        ppp = pp * p
+        p2n = p + 2 * n
+
+        N1 = -(n - 1) * f ** 2 * pp * (q - n)
+        N2a = (-8 + 8 * p + 8 * q) * n ** 3
+        N2b = (16 * pp + (-44 + 20 * q) * p + 26 - 24 * q) * n ** 2
+        N2c = (10 * ppp + (14 * q - 46) * pp + (-40 * q + 66) * p - 28 + 24 * q) * n
+        N2d = 2 * pp ** 2 + (-13 + 3 * q) * ppp + (-14 * q + 30) * pp
+        N2e = (-29 + 19 * q) * p + 10 - 8 * q
+
+        D1 = q ** 2 * (p2n - 3) ** 2
+        D2 = (p2n - 2) ** 3 * (p2n - 1) ** 2
+
+        return (N1 / D1) * (N2a + N2b + N2c + N2d + N2e) / D2
+
+    def _betainc_da_n_dq(f, p, q, n):
+        """
+        Derivative of a_n wrt q
+        """
+        if n == 1:
+            return p * f / (q * (p + 1))
+
+        p2n = p + 2 * n
+        F1 = (p ** 2 * f ** 2 / (q ** 2)) * (n - 1) * (p + n - 1) * (2 * q + p - 2)
+        D1 = (p2n - 3) * (p2n - 2) ** 2 * (p2n - 1)
+
+        return F1 / D1
+
+    def _betainc_db_n_dp(f, p, q, n):
+        """
+        Derivative of b_n wrt p
+        """
+        p2n = p + 2 * n
+        pp = p ** 2
+        q4 = 4 * q
+        p4 = 4 * p
+
+        F1 = (p * f / q) * (
+            (-p4 - q4 + 4) * n ** 2 + (p4 - 4 + q4 - 2 * pp) * n + pp * q
+        )
+        D1 = (p2n - 2) ** 2 * p2n ** 2
+
+        return F1 / D1
+
+    def _betainc_db_n_dq(f, p, q, n):
+        """
+        Derivative of b_n wrt q
+        """
+        p2n = p + 2 * n
+        return -(p ** 2 * f) / (q * (p2n - 2) * p2n)
+
+    # Input validation
+    if not (0 <= x <= 1) or p < 0 or q < 0:
+        return np.nan
+
+    # Past x = p / (p + q), solve the mirrored problem (q, p, 1 - x) instead
+    # and negate it, since I_x(p, q) = 1 - I_{1 - x}(q, p)
+    if x > (p / (p + q)):
+        return -_betainc_derivative(q, p, 1 - x, not wrtp)
+
+    min_iters = 3
+    max_iters = 200
+    err_threshold = 1e-12
+
+    derivative_old = 0
+
+    Am2, Am1 = 1, 1
+    Bm2, Bm1 = 0, 1
+    dAm2, dAm1 = 0, 0
+    dBm2, dBm1 = 0, 0
+
+    f = (q * x) / (p * (1 - x))
+    K = np.exp(
+        p * np.log(x) + (q - 1) * np.log1p(-x) - np.log(p) - scipy.special.betaln(p, q)
+    )
+    if wrtp:
+        dK = np.log(x) - 1 / p + scipy.special.digamma(p + q) - scipy.special.digamma(p)
+    else:
+        dK = np.log1p(-x) + scipy.special.digamma(p + q) - scipy.special.digamma(q)
+
+    for n in range(1, max_iters + 1):
+        a_n_ = _betainc_a_n(f, p, q, n)
+        b_n_ = _betainc_b_n(f, p, q, n)
+        if wrtp:
+            da_n = _betainc_da_n_dp(f, p, q, n)
+            db_n = _betainc_db_n_dp(f, p, q, n)
+        else:
+            da_n = _betainc_da_n_dq(f, p, q, n)
+            db_n = _betainc_db_n_dq(f, p, q, n)
+
+        # Advance the recurrences for the approximant A / B and for the
+        # derivatives dA, dB of its numerator and denominator
+        A = a_n_ * Am2 + b_n_ * Am1
+        B = a_n_ * Bm2 + b_n_ * Bm1
+        dA = da_n * Am2 + a_n_ * dAm2 + db_n * Am1 + b_n_ * dAm1
+        dB = da_n * Bm2 + a_n_ * dBm2 + db_n * Bm1 + b_n_ * dBm1
+
+        Am2, Am1 = Am1, A
+        Bm2, Bm1 = Bm1, B
+        dAm2, dAm1 = dAm1, dA
+        dBm2, dBm1 = dBm1, dB
+
+        # Skip the convergence check while n < min_iters - 1
+        if n < min_iters - 1:
+            continue
+
+        F1 = A / B
+        F2 = (dA - F1 * dB) / B
+        derivative = K * (F1 * dK + F2)
+
+        errapx = abs(derivative_old - derivative)
+        d_errapx = errapx / max(err_threshold, abs(derivative))
+        derivative_old = derivative
+
+        if d_errapx <= err_threshold:
+            break
+    else:
+        # No `break`: the tolerance was never met within `max_iters` steps
+        return np.nan
+
+    return derivative
diff --git a/aesara/tensor/inplace.py b/aesara/tensor/inplace.py
index b9879c85ec..b401b38b65 100644
--- a/aesara/tensor/inplace.py
+++ b/aesara/tensor/inplace.py
@@ -318,6 +318,11 @@ def softplus_inplace(x):
     """Compute log(1 + exp(x)), also known as softplus or log1pexp"""
 
 
+@scalar_elemwise
+def betainc_inplace(a, b, x):
+    """Regularized incomplete beta function"""
+
+
 @scalar_elemwise
 def second_inplace(a):
     """Fill `a` with `b`"""
diff --git a/aesara/tensor/math.py b/aesara/tensor/math.py
index 0c47e10f34..0f3e1da784 100644
--- a/aesara/tensor/math.py
+++ b/aesara/tensor/math.py
@@ -1417,6 +1417,11 @@ def softplus(x):
     """Compute log(1 + exp(x)), also known as softplus or log1pexp"""
 
 
+@scalar_elemwise
+def betainc(a, b, x):
+    """Regularized incomplete beta function"""
+
+
 @scalar_elemwise
 def real(z):
     """Return real component of complex-valued tensor `z`"""
@@ -2847,6 +2852,7 @@ def power(x, y):
     "sigmoid",
     "expit",
     "softplus",
+    "betainc",
     "real",
     "imag",
     "angle",
diff --git a/tests/scalar/test_math.py b/tests/scalar/test_math.py
index 08bd51fdba..7c76a254dd 100644
--- a/tests/scalar/test_math.py
+++ b/tests/scalar/test_math.py
@@ -1,9 +1,11 @@
 import numpy as np
+from numpy.testing import assert_allclose, assert_almost_equal
 
 import aesara.tensor as aet
+from aesara import function
 from aesara.graph.fg import FunctionGraph
 from aesara.link.c.basic import CLinker
-from aesara.scalar.math import gammainc, gammaincc, gammal, gammau
+from aesara.scalar.math import betainc, gammainc, gammaincc, gammal, gammau
 
 
 def test_gammainc_nan():
@@ -44,3 +46,77 @@ def test_gammau_nan():
     assert np.isnan(test_func(-1, 1))
     assert np.isnan(test_func(1, -1))
     assert np.isnan(test_func(-1, -1))
+
+
+class TestBetaIncGrad:
+
+    # This test replicates the following STAN test:
+    # https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/grad_reg_inc_beta_test.cpp
+    def test_stan_grad_combined(self):
+        a, b, z = aet.scalars("a", "b", "z")
+        betainc_out = betainc(a, b, z)
+        betainc_grad = aet.grad(betainc_out, [a, b], null_gradients="return")
+        f_grad = function([a, b, z], betainc_grad)
+
+        for test_a, test_b, test_z, expected_dda, expected_ddb in (
+            (1.0, 1.0, 1.0, 0, np.nan),
+            (1.0, 1.0, 0.4, -0.36651629, 0.30649537),
+        ):
+            assert_allclose(
+                f_grad(test_a, test_b, test_z), [expected_dda, expected_ddb]
+            )
+
+    # This test combines the following STAN tests:
+    # https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/inc_beta_dda_test.cpp
+    # https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/inc_beta_ddb_test.cpp
+    # https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/inc_beta_ddz_test.cpp
+    def test_stan_grad_partial(self):
+        a, b, z = aet.scalars("a", "b", "z")
+        betainc_out = betainc(a, b, z)
+        betainc_grad = aet.grad(betainc_out, [a, b, z])
+        f_grad = function([a, b, z], betainc_grad)
+
+        for test_a, test_b, test_z, expected_dda, expected_ddb, expected_ddz in (
+            (1.5, 1.25, 0.001, -0.00028665637, 4.41357328e-05, 0.063300692),
+            (1.5, 1.25, 0.5, -0.26038693947, 0.29301795, 1.1905416),
+            (1.5, 1.25, 0.6, -0.23806757, 0.32279575, 1.23341068),
+            (1.5, 1.25, 0.999, -0.00022264493, 0.0018969609, 0.35587692),
+            (15000, 1.25, 0.001, 0, 0, 0),
+            (15000, 1.25, 0.5, 0, 0, 0),
+            (15000, 1.25, 0.6, 0, 0, 0),
+            (15000, 1.25, 0.999, -6.59543226e-10, 2.00849793e-06, 0.009898182),
+            (1.5, 12500, 0.001, -3.93756641e-05, 1.47821755e-09, 0.1848717),
+            (1.5, 12500, 0.5, 0, 0, 0),
+            (1.5, 12500, 0.6, 0, 0, 0),
+            (1.5, 12500, 0.999, 0, 0, 0),
+            (15000, 12500, 0.001, 0, 0, 0),
+            (15000, 12500, 0.5, -8.72102443e-53, 9.55282792e-53, 5.01131256e-48),
+            (15000, 12500, 0.6, -4.085621e-14, -5.5067062e-14, 1.15135267e-71),
+            (15000, 12500, 0.999, 0, 0, 0),
+        ):
+
+            assert_almost_equal(
+                f_grad(test_a, test_b, test_z),
+                [expected_dda, expected_ddb, expected_ddz],
+                decimal=4,
+            )
+
+    # This test compares against the tabulated values in:
+    # Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function.
+    # Journal of Statistical Software, 3(1), 1-20.
+    def test_boik_robison_cox(self):
+        a, b, z = aet.scalars("a", "b", "z")
+        betainc_out = betainc(a, b, z)
+        betainc_grad = aet.grad(betainc_out, [a, b])
+        f_grad = function([a, b, z], betainc_grad)
+
+        for test_a, test_b, test_z, expected_dda, expected_ddb in (
+            (1.5, 11.0, 0.001, -4.5720356e-03, 1.1845673e-04),
+            (1.5, 11.0, 0.5, -2.5501997e-03, 9.0824388e-04),
+            (1000.0, 1000.0, 0.5, -8.9224793e-03, 8.9224793e-03),
+            (1000.0, 1000.0, 0.55, -3.6713108e-07, 4.0584118e-07),
+        ):
+            assert_almost_equal(
+                f_grad(test_a, test_b, test_z),
+                [expected_dda, expected_ddb],
+            )
diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py
index 1a89ee0406..ed38975561 100644
--- a/tests/tensor/test_math_scipy.py
+++ b/tests/tensor/test_math_scipy.py
@@ -650,3 +650,27 @@ def test_accuracy(self):
 def test_deprecated_module():
     with pytest.warns(DeprecationWarning):
         import aesara.scalar.basic_scipy  # noqa: F401
+
+
+_good_broadcast_ternary_betainc = dict(
+    normal=(
+        rand_ranged(0, 1000, (2, 3)),
+        rand_ranged(0, 1000, (2, 3)),
+        rand_ranged(0, 1, (2, 3)),
+    ),
+)
+
+TestBetaincBroadcast = makeBroadcastTester(
+    op=aet.betainc,
+    expected=scipy.special.betainc,
+    good=_good_broadcast_ternary_betainc,
+    grad=_good_broadcast_ternary_betainc,
+)
+
+TestBetaincInplaceBroadcast = makeBroadcastTester(
+    op=inplace.betainc_inplace,
+    expected=scipy.special.betainc,
+    good=_good_broadcast_ternary_betainc,
+    grad=_good_broadcast_ternary_betainc,
+    inplace=True,
+)