From 17f4b82ae0060dbac347c34bab220110d402bffd Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 16 Aug 2022 10:44:40 +0200 Subject: [PATCH] Apply suggestions from code review --- .../Standard/src/primitive/resample.birch | 25 ++++++++----------- .../Test/src/basic/test_basic_logsumexp.birch | 18 ++++++------- 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/libraries/Standard/src/primitive/resample.birch b/libraries/Standard/src/primitive/resample.birch index 9398d64ae..551edccb4 100644 --- a/libraries/Standard/src/primitive/resample.birch +++ b/libraries/Standard/src/primitive/resample.birch @@ -39,7 +39,7 @@ function nan_max(w:Real[_]) -> Real { * * @return the logarithm of the sum. * - * !!! note + * @note * NaN log weights are treated as though `-inf`. * * This uses a numerically stable implementation that avoids over- and @@ -51,8 +51,7 @@ function nan_max(w:Real[_]) -> Real { */ function log_sum_exp(w:Real[_]) -> Real { if length(w) > 0 { - // Running maximum of log weights - let mx <- -inf; + let mx <- -inf; // running maximum of log weights /* * Running sum of non-maximum weights divided by maximum weight. * In contrast to Nowozin's implementation the maximum itself @@ -60,9 +59,8 @@ function log_sum_exp(w:Real[_]) -> Real { * This avoids underflow for e.g. w = [1e-20, log(1e-20)]. */ let r <- 0.0; - wn:Real; for n in 1..length(w) { - wn <- w[n]; + let wn <- w[n]; if wn == inf { return inf; } else if wn > mx { @@ -282,7 +280,7 @@ function cumulative_weights(w:Real[_]) -> Real[_] { * @return A pair, the first element of which gives the ESS, the second * element of which gives the logarithm of the sum of weights. * - * !!! note + * @note * NaN log weights are treated as though `-inf`. * * This uses a numerically stable implementation that avoids over- and @@ -296,8 +294,7 @@ function resample_reduce(w:Real[_]) -> (Real, Real) { if length(w) == 0 { return (0.0, -inf); } else { - // Running maximum of log weights - let mx <- -inf; + let mx <- -inf; // running maximum of log weights /* * Running sum of non-maximum weights divided by maximum weight (r), * and their squares (q). @@ -307,31 +304,29 @@ function resample_reduce(w:Real[_]) -> (Real, Real) { */ let r <- 0.0; let q <- 0.0; - wn:Real; - v:Real; for n in 1..length(w) { - wn <- w[n]; + let wn <- w[n]; if wn == inf { return (1.0, inf); } else if wn > mx { - v <- exp(mx - wn); + let v <- exp(mx - wn); r <- (r + 1.0)*v; q <- (q + 1.0)*v*v; mx <- wn; } else if isfinite(wn) { - v <- exp(wn - mx); + let v <- exp(wn - mx); r <- r + v; q <- q + v*v; } } - // If all weights are `-inf` or `nan`, the result is the same as for empty arrays + /* if all weights are `-inf` or `nan`, the result is the same as for empty arrays */ if mx==-inf { return (0.0, -inf); } let log_sum_weights <- mx + log1p(r); - // The ESS is estimated as (sum w)^2 / (sum w^2) + /* the ESS is estimated as (sum w)^2 / (sum w^2) */ let rp1 <- r + 1.0; let ess <- rp1*rp1/(q + 1.0); return (ess, log_sum_weights); diff --git a/tests/Test/src/basic/test_basic_logsumexp.birch b/tests/Test/src/basic/test_basic_logsumexp.birch index 21c4d5320..59d6bddd0 100644 --- a/tests/Test/src/basic/test_basic_logsumexp.birch +++ b/tests/Test/src/basic/test_basic_logsumexp.birch @@ -4,20 +4,20 @@ * `resample_reduce`. */ program test_basic_logsumexp() { - // Generate random weights + /* generate random weights */ w:Real[100]; for n in 1..100 { w[n] <- simulate_gaussian(0.0, 1.0); } - // Compare with common two-pass algorithm + /* compare with common two-pass algorithm */ let y <- log_sum_exp_twopass(w); let ess <- exp(2*y - log_sum_exp_twopass(2*w)); if !check_ess_logsumexp(w, ess, y) { exit(1); } - // Check overflow + /* check overflow */ w <- w + 1000.0; y <- log_sum_exp_twopass(w); ess <- exp(2*y - log_sum_exp_twopass(2*w)); @@ -25,7 +25,7 @@ program test_basic_logsumexp() { exit(1); } - // Check underflow + /* check underflow */ w <- [1e-20, log(1e-20)]; y <- 2e-20; ess <- 1.0; @@ -33,13 +33,13 @@ program test_basic_logsumexp() { exit(1); } - // Check empty input + /* check empty input */ x:Real[0]; if !check_ess_logsumexp(x, 0.0, -inf) { exit(1); } - // Special cases involving -inf, inf, and nan. + /* special cases involving -inf, inf, and nan */ let cases <- [[-inf, -inf, 0.0, -inf], [-inf, nan, 0.0, -inf], [nan, -inf, 0.0, -inf], @@ -71,7 +71,7 @@ program test_basic_logsumexp() { * * @return the logarithm of the sum. * - * !!! note + * @note * This implementation uses the common two-pass algorithm * that avoids overflow. */ @@ -96,7 +96,7 @@ function log_sum_exp_twopass(w:Real[_]) -> Real { * @return Are the results approximately equal to the expected values? */ function check_ess_logsumexp(w:Real[_], ess_expected:Real, y_expected:Real) -> Boolean { - // Roughly half of the significant digits should be correct. + /* roughly half of the significant digits should be correct */ return check_ess_logsumexp(w, ess_expected, y_expected, 1e-8); } @@ -140,7 +140,7 @@ function check_ess_logsumexp(w:Real[_], ess_expected:Real, y_expected:Real, relt */ function approx_equal(x1:Real, x2:Real, reltol:Real) -> Boolean { // Handle special cases such as `inf`, `-inf` etc. not covered below. - if x1==x2 { + if x1 == x2 { return true; }