Skip to content

Commit

Permalink
fix issues
Browse files Browse the repository at this point in the history
  • Loading branch information
lingium committed Nov 13, 2024
1 parent fe8055d commit 9594228
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions stan/math/prim/prob/beta_neg_binomial_cdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ inline return_type_t<T_r, T_alpha, T_beta> beta_neg_binomial_cdf(
T_partials_return cdf(1.0);
auto ops_partials = make_partials_propagator(r_ref, alpha_ref, beta_ref);
for (size_t i = 0; i < max_size_seq_view; i++) {
// Explicit return for extreme values
// Explicit results for extreme values
// The gradients are technically ill-defined, but treated as zero
if (n_vec.val(i) == std::numeric_limits<int>::max()) {
return 1.0;
Expand All @@ -101,12 +101,12 @@ inline return_type_t<T_r, T_alpha, T_beta> beta_neg_binomial_cdf(
1.0);
auto C = lgamma(r_plus_n + 1.0) + lbeta(a_plus_r, b_plus_n + 1.0)
- lgamma(r_dbl) - lbeta(alpha_dbl, beta_dbl) - lgamma(n_dbl + 2.0);
auto ccdf = stan::math::exp(C) * F;
auto ccdf = stan::math::exp(C + stan::math::log(F));
cdf *= 1.0 - ccdf;

if constexpr (!is_constant_all<T_r, T_alpha, T_beta>::value) {
auto chain_rule_term = -ccdf / (1.0 - ccdf);
auto digamma_n_r_alpha_beta = digamma(a_plus_r + b_plus_n + 1.0);
auto digamma_n_r_alpha_beta = digamma(a_plus_r + b_plus_n + 1);
T_partials_return dF[6];
grad_F32<false, !is_constant<T_beta>::value, !is_constant_all<T_r>::value,
false, true, false>(dF, 1.0, b_plus_n + 1.0, r_plus_n + 1.0,
Expand All @@ -116,15 +116,17 @@ inline return_type_t<T_r, T_alpha, T_beta> beta_neg_binomial_cdf(
if constexpr (!is_constant<T_r>::value || !is_constant<T_alpha>::value) {
auto digamma_r_alpha = digamma(a_plus_r);
if constexpr (!is_constant<T_r>::value) {
auto partial_lccdf = digamma(r_plus_n + 1.0)
+ (digamma_r_alpha - digamma_n_r_alpha_beta)
+ (dF[2] + dF[4]) / F - digamma(r_dbl);
partials<0>(ops_partials)[i] += partial_lccdf * chain_rule_term;
partials<0>(ops_partials)[i]
+= (digamma(r_plus_n + 1)
+ (digamma_r_alpha - digamma_n_r_alpha_beta)
+ (dF[2] + dF[4]) / F - digamma(r_dbl))
* chain_rule_term;
}
if constexpr (!is_constant<T_alpha>::value) {
auto partial_lccdf = digamma_r_alpha - digamma_n_r_alpha_beta
+ dF[4] / F - digamma(alpha_dbl);
partials<1>(ops_partials)[i] += partial_lccdf * chain_rule_term;
partials<1>(ops_partials)[i]
+= (digamma_r_alpha - digamma_n_r_alpha_beta + dF[4] / F
- digamma(alpha_dbl))
* chain_rule_term;
}
}

Expand All @@ -135,10 +137,11 @@ inline return_type_t<T_r, T_alpha, T_beta> beta_neg_binomial_cdf(
partials<1>(ops_partials)[i] += digamma_alpha_beta * chain_rule_term;
}
if constexpr (!is_constant<T_beta>::value) {
auto partial_lccdf = digamma(b_plus_n + 1.0) - digamma_n_r_alpha_beta
+ (dF[1] + dF[4]) / F
- (digamma(beta_dbl) - digamma_alpha_beta);
partials<2>(ops_partials)[i] += partial_lccdf * chain_rule_term;
partials<2>(ops_partials)[i]
+= (digamma(b_plus_n + 1) - digamma_n_r_alpha_beta
+ (dF[1] + dF[4]) / F
- (digamma(beta_dbl) - digamma_alpha_beta))
* chain_rule_term;
}
}
}
Expand Down

0 comments on commit 9594228

Please sign in to comment.