Skip to content

Commit

Permalink
adds template parameters for grad_f32 for only calculating gradients …
Browse files Browse the repository at this point in the history
…that are needed. Small cleanup for beta_neg_binomial_lccdf
  • Loading branch information
SteveBronder committed Oct 25, 2024
1 parent 7fc9aab commit 939f03c
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 112 deletions.
123 changes: 68 additions & 55 deletions stan/math/prim/fun/grad_F32.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,14 @@ namespace math {
* This power-series representation converges for all gradients
* under the same conditions as the 3F2 function itself.
*
 * @tparam grad_a1 whether to compute the gradient with respect to a1
 * @tparam grad_a2 whether to compute the gradient with respect to a2
 * @tparam grad_a3 whether to compute the gradient with respect to a3
 * @tparam grad_b1 whether to compute the gradient with respect to b1
 * @tparam grad_b2 whether to compute the gradient with respect to b2
 * @tparam grad_z whether to compute the gradient with respect to z
 * @tparam T1 type of the elements of the output array g
 * @tparam T2 type of a1
 * @tparam T3 type of a2
 * @tparam T4 type of a3
 * @tparam T5 type of b1
 * @tparam T6 type of b2
 * @tparam T7 type of z
 * @tparam T8 type of precision
 * @param[out] g pointer to array of six values of type T1, result.
 * @param[in] a1 see generalized hypergeometric function definition.
 * @param[in] a2 see generalized hypergeometric function definition.
Expand All @@ -35,84 +42,90 @@ namespace math {
* @param[in] precision precision of the infinite sum
* @param[in] max_steps number of steps to take
*/
template <typename T>
void grad_F32(T* g, const T& a1, const T& a2, const T& a3, const T& b1,
const T& b2, const T& z, const T& precision = 1e-6,
template <bool grad_a1 = true, bool grad_a2 = true, bool grad_a3 = true,
bool grad_b1 = true, bool grad_b2 = true, bool grad_z = true,
typename T1, typename T2, typename T3, typename T4, typename T5,
typename T6, typename T7, typename T8 = double>
void grad_F32(T1* g, const T2& a1, const T3& a2, const T4& a3, const T5& b1,
const T6& b2, const T7& z, const T8& precision = 1e-6,
int max_steps = 1e5) {
check_3F2_converges("grad_F32", a1, a2, a3, b1, b2, z);

using std::exp;
using std::fabs;
using std::log;

for (int i = 0; i < 6; ++i) {
g[i] = 0.0;
}

T log_g_old[6];
T1 log_g_old[6];
for (auto& x : log_g_old) {
x = NEGATIVE_INFTY;
}

T log_t_old = 0.0;
T log_t_new = 0.0;
T1 log_t_old = 0.0;
T1 log_t_new = 0.0;

T log_z = log(z);
T7 log_z = log(z);

double log_t_new_sign = 1.0;
double log_t_old_sign = 1.0;
double log_g_old_sign[6];
T1 log_t_new_sign = 1.0;
T1 log_t_old_sign = 1.0;
T1 log_g_old_sign[6];
for (int i = 0; i < 6; ++i) {
log_g_old_sign[i] = 1.0;
}

std::array<T1, 6> term{0};
for (int k = 0; k <= max_steps; ++k) {
T p = (a1 + k) * (a2 + k) * (a3 + k) / ((b1 + k) * (b2 + k) * (1 + k));
T1 p = (a1 + k) * (a2 + k) * (a3 + k) / ((b1 + k) * (b2 + k) * (1 + k));
if (p == 0) {
return;
}

log_t_new += log(fabs(p)) + log_z;
log_t_new_sign = p >= 0.0 ? log_t_new_sign : -log_t_new_sign;
if constexpr (grad_a1) {
term[0] = log_g_old_sign[0] * log_t_old_sign * exp(log_g_old[0] - log_t_old)
+ inv(a1 + k);
log_g_old[0] = log_t_new + log(fabs(term[0]));
log_g_old_sign[0] = term[0] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
g[0] += log_g_old_sign[0] * exp(log_g_old[0]);
}

if constexpr (grad_a2) {
term[1] = log_g_old_sign[1] * log_t_old_sign * exp(log_g_old[1] - log_t_old)
+ inv(a2 + k);
log_g_old[1] = log_t_new + log(fabs(term[1]));
log_g_old_sign[1] = term[1] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
g[1] += log_g_old_sign[1] * exp(log_g_old[1]);
}

if constexpr (grad_a3) {
term[2] = log_g_old_sign[2] * log_t_old_sign * exp(log_g_old[2] - log_t_old)
+ inv(a3 + k);
log_g_old[2] = log_t_new + log(fabs(term[2]));
log_g_old_sign[2] = term[2] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
g[2] += log_g_old_sign[2] * exp(log_g_old[2]);
}

if constexpr (grad_b1) {
term[3] = log_g_old_sign[3] * log_t_old_sign * exp(log_g_old[3] - log_t_old)
- inv(b1 + k);
log_g_old[3] = log_t_new + log(fabs(term[3]));
log_g_old_sign[3] = term[3] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
g[3] += log_g_old_sign[3] * exp(log_g_old[3]);
}

if constexpr (grad_b2) {
term[4] = log_g_old_sign[4] * log_t_old_sign * exp(log_g_old[4] - log_t_old)
- inv(b2 + k);
log_g_old[4] = log_t_new + log(fabs(term[4]));
log_g_old_sign[4] = term[4] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
g[4] += log_g_old_sign[4] * exp(log_g_old[4]);
}

// g_old[0] = t_new * (g_old[0] / t_old + 1.0 / (a1 + k));
T term = log_g_old_sign[0] * log_t_old_sign * exp(log_g_old[0] - log_t_old)
+ inv(a1 + k);
log_g_old[0] = log_t_new + log(fabs(term));
log_g_old_sign[0] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;

// g_old[1] = t_new * (g_old[1] / t_old + 1.0 / (a2 + k));
term = log_g_old_sign[1] * log_t_old_sign * exp(log_g_old[1] - log_t_old)
+ inv(a2 + k);
log_g_old[1] = log_t_new + log(fabs(term));
log_g_old_sign[1] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;

// g_old[2] = t_new * (g_old[2] / t_old + 1.0 / (a3 + k));
term = log_g_old_sign[2] * log_t_old_sign * exp(log_g_old[2] - log_t_old)
+ inv(a3 + k);
log_g_old[2] = log_t_new + log(fabs(term));
log_g_old_sign[2] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;

// g_old[3] = t_new * (g_old[3] / t_old - 1.0 / (b1 + k));
term = log_g_old_sign[3] * log_t_old_sign * exp(log_g_old[3] - log_t_old)
- inv(b1 + k);
log_g_old[3] = log_t_new + log(fabs(term));
log_g_old_sign[3] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;

// g_old[4] = t_new * (g_old[4] / t_old - 1.0 / (b2 + k));
term = log_g_old_sign[4] * log_t_old_sign * exp(log_g_old[4] - log_t_old)
- inv(b2 + k);
log_g_old[4] = log_t_new + log(fabs(term));
log_g_old_sign[4] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;

// g_old[5] = t_new * (g_old[5] / t_old + 1.0 / z);
term = log_g_old_sign[5] * log_t_old_sign * exp(log_g_old[5] - log_t_old)
+ inv(z);
log_g_old[5] = log_t_new + log(fabs(term));
log_g_old_sign[5] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;

for (int i = 0; i < 6; ++i) {
g[i] += log_g_old_sign[i] * exp(log_g_old[i]);
if constexpr (grad_z) {
term[5] = log_g_old_sign[5] * log_t_old_sign * exp(log_g_old[5] - log_t_old)
+ inv(z);
log_g_old[5] = log_t_new + log(fabs(term[5]));
log_g_old_sign[5] = term[5] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
g[5] += log_g_old_sign[5] * exp(log_g_old[5]);
}

if (log_t_new <= log(precision)) {
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/grad_pFq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ template <bool calc_a = true, bool calc_b = true, bool calc_z = true,
typename T_Rtn = return_type_t<Ta, Tb, Tz>,
typename Ta_Rtn = promote_scalar_t<T_Rtn, plain_type_t<Ta>>,
typename Tb_Rtn = promote_scalar_t<T_Rtn, plain_type_t<Tb>>>
std::tuple<Ta_Rtn, Tb_Rtn, T_Rtn> grad_pFq(const TpFq& pfq_val, const Ta& a,
inline std::tuple<Ta_Rtn, Tb_Rtn, T_Rtn> grad_pFq(const TpFq& pfq_val, const Ta& a,
const Tb& b, const Tz& z,
double precision = 1e-14,
int max_steps = 1e6) {
Expand Down
39 changes: 18 additions & 21 deletions stan/math/prim/fun/hypergeometric_3F2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,33 @@ namespace stan {
namespace math {
namespace internal {
template <typename Ta, typename Tb, typename Tz,
typename T_return = return_type_t<Ta, Tb, Tz>,
typename ArrayAT = Eigen::Array<scalar_type_t<Ta>, 3, 1>,
typename ArrayBT = Eigen::Array<scalar_type_t<Ta>, 3, 1>,
require_all_vector_t<Ta, Tb>* = nullptr,
require_stan_scalar_t<Tz>* = nullptr>
T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
inline return_type_t<Ta, Tb, Tz> hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
double precision = 1e-6,
int max_steps = 1e5) {
ArrayAT a_array = as_array_or_scalar(a);
ArrayBT b_array = append_row(as_array_or_scalar(b), 1.0);
using T_return = return_type_t<Ta, Tb, Tz>;
Eigen::Array<scalar_type_t<Ta>, 3, 1> a_array = as_array_or_scalar(a);
Eigen::Array<scalar_type_t<Tb>, 3, 1> b_array = append_row(as_array_or_scalar(b), 1.0);
check_3F2_converges("hypergeometric_3F2", a_array[0], a_array[1], a_array[2],
b_array[0], b_array[1], z);

T_return t_acc = 1.0;
T_return log_t = 0.0;
T_return log_z = log(fabs(z));
Eigen::ArrayXi a_signs = sign(value_of_rec(a_array));
Eigen::ArrayXi b_signs = sign(value_of_rec(b_array));
plain_type_t<decltype(a_array)> apk = a_array;
plain_type_t<decltype(b_array)> bpk = b_array;
auto log_z = log(fabs(z));
Eigen::Array<int, 3, 1> a_signs = sign(value_of_rec(a_array));
Eigen::Array<int, 3, 1> b_signs = sign(value_of_rec(b_array));
int z_sign = sign(value_of_rec(z));
int t_sign = z_sign * a_signs.prod() * b_signs.prod();

int k = 0;
while (k <= max_steps && log_t >= log(precision)) {
const double log_precision = log(precision);
while (k <= max_steps && log_t >= log_precision) {
// Replace zero values with 1 prior to taking the log so that we accumulate
// 0.0 rather than -inf
const auto& abs_apk = math::fabs((apk == 0).select(1.0, apk));
const auto& abs_bpk = math::fabs((bpk == 0).select(1.0, bpk));
T_return p = sum(log(abs_apk)) - sum(log(abs_bpk));
const auto& abs_apk = math::fabs((a_array == 0).select(1.0, a_array));
const auto& abs_bpk = math::fabs((b_array == 0).select(1.0, b_array));
auto p = sum(log(abs_apk)) - sum(log(abs_bpk));
if (p == NEGATIVE_INFTY) {
return t_acc;
}
Expand All @@ -59,10 +56,10 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
"overflow hypergeometric function did not converge.");
}
k++;
apk.array() += 1.0;
bpk.array() += 1.0;
a_signs = sign(value_of_rec(apk));
b_signs = sign(value_of_rec(bpk));
a_array += 1.0;
b_array += 1.0;
a_signs = sign(value_of_rec(a_array));
b_signs = sign(value_of_rec(b_array));
t_sign = a_signs.prod() * b_signs.prod() * t_sign;
}
if (k == max_steps) {
Expand Down Expand Up @@ -115,7 +112,7 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
template <typename Ta, typename Tb, typename Tz,
require_all_vector_t<Ta, Tb>* = nullptr,
require_stan_scalar_t<Tz>* = nullptr>
auto hypergeometric_3F2(const Ta& a, const Tb& b, const Tz& z) {
inline auto hypergeometric_3F2(const Ta& a, const Tb& b, const Tz& z) {
check_3F2_converges("hypergeometric_3F2", a[0], a[1], a[2], b[0], b[1], z);
// Boost's pFq throws convergence errors in some cases, fallback to naive
// infinite-sum approach (tests pass for these)
Expand Down Expand Up @@ -143,7 +140,7 @@ auto hypergeometric_3F2(const Ta& a, const Tb& b, const Tz& z) {
*/
template <typename Ta, typename Tb, typename Tz,
require_all_stan_scalar_t<Ta, Tb, Tz>* = nullptr>
auto hypergeometric_3F2(const std::initializer_list<Ta>& a,
inline auto hypergeometric_3F2(const std::initializer_list<Ta>& a,
const std::initializer_list<Tb>& b, const Tz& z) {
return hypergeometric_3F2(std::vector<Ta>(a), std::vector<Tb>(b), z);
}
Expand Down
67 changes: 32 additions & 35 deletions stan/math/prim/prob/beta_neg_binomial_lccdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,6 @@ template <typename T_n, typename T_r, typename T_alpha, typename T_beta>
inline return_type_t<T_r, T_alpha, T_beta> beta_neg_binomial_lccdf(
const T_n& n, const T_r& r, const T_alpha& alpha, const T_beta& beta,
const double precision = 1e-8, const int max_steps = 1e6) {
using std::exp;
using std::log;
using T_partials_return = partials_return_t<T_n, T_r, T_alpha, T_beta>;
using T_r_ref = ref_type_t<T_r>;
using T_alpha_ref = ref_type_t<T_alpha>;
using T_beta_ref = ref_type_t<T_beta>;
static constexpr const char* function = "beta_neg_binomial_lccdf";
check_consistent_sizes(
function, "Failures variable", n, "Number of successes parameter", r,
Expand All @@ -57,67 +51,70 @@ inline return_type_t<T_r, T_alpha, T_beta> beta_neg_binomial_lccdf(
return 0;
}

using T_r_ref = ref_type_t<T_r>;
T_r_ref r_ref = r;
using T_alpha_ref = ref_type_t<T_alpha>;
T_alpha_ref alpha_ref = alpha;
using T_beta_ref = ref_type_t<T_beta>;
T_beta_ref beta_ref = beta;
check_positive_finite(function, "Number of successes parameter", r_ref);
check_positive_finite(function, "Prior success parameter", alpha_ref);
check_positive_finite(function, "Prior failure parameter", beta_ref);

T_partials_return log_ccdf(0.0);
auto ops_partials = make_partials_propagator(r_ref, alpha_ref, beta_ref);

scalar_seq_view<T_n> n_vec(n);
scalar_seq_view<T_r_ref> r_vec(r_ref);
scalar_seq_view<T_alpha_ref> alpha_vec(alpha_ref);
scalar_seq_view<T_beta_ref> beta_vec(beta_ref);
size_t size_n = stan::math::size(n);
int size_n = stan::math::size(n);
size_t max_size_seq_view = max_size(n, r, alpha, beta);

// Explicit return for extreme values
// The gradients are technically ill-defined, but treated as zero
for (size_t i = 0; i < size_n; i++) {
for (int i = 0; i < size_n; i++) {
if (n_vec.val(i) < 0) {
return ops_partials.build(0.0);
return 0.0;
}
}

using T_partials_return = partials_return_t<T_n, T_r, T_alpha, T_beta>;
T_partials_return log_ccdf(0.0);
auto ops_partials = make_partials_propagator(r_ref, alpha_ref, beta_ref);
for (size_t i = 0; i < max_size_seq_view; i++) {
// Explicit return for extreme values
// The gradients are technically ill-defined, but treated as zero
if (n_vec.val(i) == std::numeric_limits<int>::max()) {
return ops_partials.build(negative_infinity());
}
T_partials_return n_dbl = n_vec.val(i);
T_partials_return r_dbl = r_vec.val(i);
T_partials_return alpha_dbl = alpha_vec.val(i);
T_partials_return beta_dbl = beta_vec.val(i);
T_partials_return b_plus_n = beta_dbl + n_dbl;
T_partials_return r_plus_n = r_dbl + n_dbl;
T_partials_return a_plus_r = alpha_dbl + r_dbl;
T_partials_return one = 1;
T_partials_return precision_t
= precision; // default -6, set -8 to pass all tests

T_partials_return F
= hypergeometric_3F2({one, b_plus_n + 1, r_plus_n + 1},
{n_dbl + 2, a_plus_r + b_plus_n + 1}, one);
T_partials_return C = lgamma(r_plus_n + 1) + lbeta(a_plus_r, b_plus_n + 1)
auto n_dbl = n_vec.val(i);
auto r_dbl = r_vec.val(i);
auto alpha_dbl = alpha_vec.val(i);
auto beta_dbl = beta_vec.val(i);
auto b_plus_n = beta_dbl + n_dbl;
auto r_plus_n = r_dbl + n_dbl;
auto a_plus_r = alpha_dbl + r_dbl;
using a_t = return_type_t<decltype(b_plus_n), decltype(r_plus_n)>;
using b_t = return_type_t<decltype(n_dbl), decltype(a_plus_r), decltype(b_plus_n)>;
auto F
= hypergeometric_3F2(
std::initializer_list<a_t>{1.0, b_plus_n + 1.0, r_plus_n + 1.0},
std::initializer_list<b_t>{n_dbl + 2.0, a_plus_r + b_plus_n + 1.0}, 1.0);
auto C = lgamma(r_plus_n + 1.0) + lbeta(a_plus_r, b_plus_n + 1.0)
- lgamma(r_dbl) - lbeta(alpha_dbl, beta_dbl)
- lgamma(n_dbl + 2);
T_partials_return ccdf = exp(C) * F;
T_partials_return log_ccdf_i = log(ccdf);
log_ccdf += log_ccdf_i;
log_ccdf += C + stan::math::log(F);

if constexpr (!is_constant_all<T_r, T_alpha, T_beta>::value) {
T_partials_return digamma_n_r_alpha_beta
= digamma(a_plus_r + b_plus_n + 1);
auto digamma_n_r_alpha_beta
= digamma(a_plus_r + b_plus_n + 1.0);
T_partials_return dF[6];
grad_F32(dF, one, b_plus_n + 1, r_plus_n + 1, n_dbl + 2,
a_plus_r + b_plus_n + 1, one, precision_t, max_steps);
grad_F32<false, !is_constant<T_beta>::value,
!is_constant_all<T_r>::value, false, true, false>(dF, 1.0,
b_plus_n + 1.0, r_plus_n + 1.0, n_dbl + 2.0,
a_plus_r + b_plus_n + 1.0, 1.0, precision, max_steps);

if constexpr (!is_constant<T_r>::value || !is_constant<T_alpha>::value) {
T_partials_return digamma_r_alpha = digamma(a_plus_r);
auto digamma_r_alpha = digamma(a_plus_r);
if constexpr (!is_constant_all<T_r>::value) {
partials<0>(ops_partials)[i]
+= digamma(r_plus_n + 1)
Expand All @@ -133,7 +130,7 @@ inline return_type_t<T_r, T_alpha, T_beta> beta_neg_binomial_lccdf(

if constexpr (!is_constant<T_alpha>::value
|| !is_constant<T_beta>::value) {
T_partials_return digamma_alpha_beta = digamma(alpha_dbl + beta_dbl);
auto digamma_alpha_beta = digamma(alpha_dbl + beta_dbl);
if constexpr (!is_constant<T_alpha>::value) {
partials<1>(ops_partials)[i] += digamma_alpha_beta;
}
Expand Down

0 comments on commit 939f03c

Please sign in to comment.