Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Aug 16, 2022
1 parent 4971374 commit 17f4b82
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 24 deletions.
25 changes: 10 additions & 15 deletions libraries/Standard/src/primitive/resample.birch
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ function nan_max(w:Real[_]) -> Real {
*
* @return the logarithm of the sum.
*
* !!! note
* @note
* NaN log weights are treated as though `-inf`.
*
* This uses a numerically stable implementation that avoids over- and
Expand All @@ -51,18 +51,16 @@ function nan_max(w:Real[_]) -> Real {
*/
function log_sum_exp(w:Real[_]) -> Real {
if length(w) > 0 {
// Running maximum of log weights
let mx <- -inf;
let mx <- -inf; // running maximum of log weights
/*
* Running sum of non-maximum weights divided by maximum weight.
* In contrast to Nowozin's implementation the maximum itself
* is not included in the sum.
* This avoids underflow for e.g. w = [1e-20, log(1e-20)].
*/
let r <- 0.0;
wn:Real;
for n in 1..length(w) {
wn <- w[n];
let wn <- w[n];
if wn == inf {
return inf;
} else if wn > mx {
Expand Down Expand Up @@ -282,7 +280,7 @@ function cumulative_weights(w:Real[_]) -> Real[_] {
* @return A pair, the first element of which gives the ESS, the second
* element of which gives the logarithm of the sum of weights.
*
* !!! note
* @note
* NaN log weights are treated as though `-inf`.
*
* This uses a numerically stable implementation that avoids over- and
Expand All @@ -296,8 +294,7 @@ function resample_reduce(w:Real[_]) -> (Real, Real) {
if length(w) == 0 {
return (0.0, -inf);
} else {
// Running maximum of log weights
let mx <- -inf;
let mx <- -inf; // running maximum of log weights
/*
* Running sum of non-maximum weights divided by maximum weight (r),
* and their squares (q).
Expand All @@ -307,31 +304,29 @@ function resample_reduce(w:Real[_]) -> (Real, Real) {
*/
let r <- 0.0;
let q <- 0.0;
wn:Real;
v:Real;
for n in 1..length(w) {
wn <- w[n];
let wn <- w[n];
if wn == inf {
return (1.0, inf);
} else if wn > mx {
v <- exp(mx - wn);
let v <- exp(mx - wn);
r <- (r + 1.0)*v;
q <- (q + 1.0)*v*v;
mx <- wn;
} else if isfinite(wn) {
v <- exp(wn - mx);
let v <- exp(wn - mx);
r <- r + v;
q <- q + v*v;
}
}

// If all weights are `-inf` or `nan`, the result is the same as for empty arrays
/* if all weights are `-inf` or `nan`, the result is the same as for empty arrays */
if mx==-inf {
return (0.0, -inf);
}

let log_sum_weights <- mx + log1p(r);
// The ESS is estimated as (sum w)^2 / (sum w^2)
/* the ESS is estimated as (sum w)^2 / (sum w^2) */
let rp1 <- r + 1.0;
let ess <- rp1*rp1/(q + 1.0);
return (ess, log_sum_weights);
Expand Down
18 changes: 9 additions & 9 deletions tests/Test/src/basic/test_basic_logsumexp.birch
Original file line number Diff line number Diff line change
Expand Up @@ -4,42 +4,42 @@
* `resample_reduce`.
*/
program test_basic_logsumexp() {
// Generate random weights
/* generate random weights */
w:Real[100];
for n in 1..100 {
w[n] <- simulate_gaussian(0.0, 1.0);
}

// Compare with common two-pass algorithm
/* compare with common two-pass algorithm */
let y <- log_sum_exp_twopass(w);
let ess <- exp(2*y - log_sum_exp_twopass(2*w));
if !check_ess_logsumexp(w, ess, y) {
exit(1);
}

// Check overflow
/* check overflow */
w <- w + 1000.0;
y <- log_sum_exp_twopass(w);
ess <- exp(2*y - log_sum_exp_twopass(2*w));
if !check_ess_logsumexp(w, ess, y) {
exit(1);
}

// Check underflow
/* check underflow */
w <- [1e-20, log(1e-20)];
y <- 2e-20;
ess <- 1.0;
if !check_ess_logsumexp(w, ess, y) {
exit(1);
}

// Check empty input
/* check empty input */
x:Real[0];
if !check_ess_logsumexp(x, 0.0, -inf) {
exit(1);
}

// Special cases involving -inf, inf, and nan.
/* special cases involving -inf, inf, and nan */
let cases <- [[-inf, -inf, 0.0, -inf],
[-inf, nan, 0.0, -inf],
[nan, -inf, 0.0, -inf],
Expand Down Expand Up @@ -71,7 +71,7 @@ program test_basic_logsumexp() {
*
* @return the logarithm of the sum.
*
* !!! note
* @note
* This implementation uses the common two-pass algorithm
* that avoids overflow.
*/
Expand All @@ -96,7 +96,7 @@ function log_sum_exp_twopass(w:Real[_]) -> Real {
* @return Are the results approximately equal to the expected values?
*/
function check_ess_logsumexp(w:Real[_], ess_expected:Real, y_expected:Real) -> Boolean {
// Roughly half of the significant digits should be correct.
/* roughly half of the significant digits should be correct */
return check_ess_logsumexp(w, ess_expected, y_expected, 1e-8);
}

Expand Down Expand Up @@ -140,7 +140,7 @@ function check_ess_logsumexp(w:Real[_], ess_expected:Real, y_expected:Real, relt
*/
function approx_equal(x1:Real, x2:Real, reltol:Real) -> Boolean {
// Handle special cases such as `inf`, `-inf` etc. not covered below.
if x1==x2 {
if x1 == x2 {
return true;
}

Expand Down

0 comments on commit 17f4b82

Please sign in to comment.