Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use numerically stable one-pass algorithm for log-sum-exp #18

Merged
merged 24 commits into from
Aug 21, 2022
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 79 additions & 18 deletions libraries/Standard/src/primitive/resample.birch
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,38 @@ function nan_max(w:Real[_]) -> Real {
*
* !!! note
devmotion marked this conversation as resolved.
Show resolved Hide resolved
* NaN log weights are treated as though `-inf`.
*
* This uses a numerically stable implementation that avoids over- and
* underflow, using a single pass over the data.
* It is based on:
*
* S. Nowozin (2016). Streaming Log-sum-exp Computation.
* http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html
*/
function log_sum_exp(w:Real[_]) -> Real {
if length(w) > 0 {
let mx <- nan_max(w);
let r <- transform_reduce(w, 0.0, \(x:Real, y:Real) -> { return x + y; },
\(x:Real) -> { return nan_exp(x - mx); });
return mx + log(r);
// Running maximum of log weights
let mx <- -inf;
devmotion marked this conversation as resolved.
Show resolved Hide resolved
/*
* Running sum of non-maximum weights divided by maximum weight.
* In contrast to Nowozin's implementation the maximum itself
* is not included in the sum.
* This avoids underflow for e.g. w = [1e-20, log(1e-20)].
*/
let r <- 0.0;
wn:Real;
devmotion marked this conversation as resolved.
Show resolved Hide resolved
for n in 1..length(w) {
wn <- w[n];
devmotion marked this conversation as resolved.
Show resolved Hide resolved
if wn == inf {
return inf;
} else if wn > mx {
r <- (r + 1.0)*exp(mx - wn);
mx <- wn;
} else if isfinite(wn) {
r <- r + exp(wn - mx);
}
}
return mx + log1p(r);
} else {
return -inf;
}
Expand All @@ -67,10 +92,7 @@ function norm_exp(w:Real[_]) -> Real[_] {
if length(w) == 0 {
return w;
} else {
let mx <- nan_max(w);
let r <- transform_reduce(w, 0.0, \(x:Real, y:Real) -> { return x + y; },
\(x:Real) -> { return nan_exp(x - mx); });
let W <- mx + log(r);
let W <- log_sum_exp(w);
return transform(w, \(x:Real) -> { return nan_exp(x - W); });
}
}
Expand Down Expand Up @@ -259,20 +281,59 @@ function cumulative_weights(w:Real[_]) -> Real[_] {
*
* @return A pair, the first element of which gives the ESS, the second
* element of which gives the logarithm of the sum of weights.
*
* !!! note
devmotion marked this conversation as resolved.
Show resolved Hide resolved
* NaN log weights are treated as though `-inf`.
*
* This uses a numerically stable implementation that avoids over- and
* underflow, using a single pass over the data.
* It is based on:
*
* S. Nowozin (2016). Streaming Log-sum-exp Computation.
* http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html
*/
function resample_reduce(w:Real[_]) -> (Real, Real) {
if length(w) == 0 {
return (0.0, 0.0);
return (0.0, -inf);
} else {
let N <- length(w);
let W <- 0.0;
let W2 <- 0.0;
let mx <- nan_max(w);
for n in 1..N {
let v <- nan_exp(w[n] - mx);
W <- W + v;
W2 <- W2 + v*v;
// Running maximum of log weights
let mx <- -inf;
devmotion marked this conversation as resolved.
Show resolved Hide resolved
/*
* Running sum of non-maximum weights divided by maximum weight (r),
* and their squares (q).
* In contrast to Nowozin's implementation the maximum itself
* is not included in the sum.
* This avoids underflow for e.g. w = [1e-20, log(1e-20)].
*/
let r <- 0.0;
let q <- 0.0;
wn:Real;
devmotion marked this conversation as resolved.
Show resolved Hide resolved
v:Real;
devmotion marked this conversation as resolved.
Show resolved Hide resolved
for n in 1..length(w) {
wn <- w[n];
devmotion marked this conversation as resolved.
Show resolved Hide resolved
if wn == inf {
return (1.0, inf);
} else if wn > mx {
v <- exp(mx - wn);
devmotion marked this conversation as resolved.
Show resolved Hide resolved
r <- (r + 1.0)*v;
q <- (q + 1.0)*v*v;
mx <- wn;
} else if isfinite(wn) {
v <- exp(wn - mx);
devmotion marked this conversation as resolved.
Show resolved Hide resolved
r <- r + v;
q <- q + v*v;
}
}
return (W*W/W2, log(W) + mx);

// If all weights are `-inf` or `nan`, the result is the same as for empty arrays
devmotion marked this conversation as resolved.
Show resolved Hide resolved
if mx==-inf {
return (0.0, -inf);
}

let log_sum_weights <- mx + log1p(r);
// The ESS is estimated as (sum w)^2 / (sum w^2)
devmotion marked this conversation as resolved.
Show resolved Hide resolved
let rp1 <- r + 1.0;
let ess <- rp1*rp1/(q + 1.0);
return (ess, log_sum_weights);
}
}
148 changes: 148 additions & 0 deletions tests/Test/src/basic/test_basic_logsumexp.birch
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@

/*
 * Test log-sum-exp implementations in `log_sum_exp` and
 * `resample_reduce`.
 *
 * Each scenario is verified with `check_ess_logsumexp`; the program exits
 * with status 1 at the first scenario that fails.
 */
program test_basic_logsumexp() {
  // Generate random weights
  w:Real[100];
  for n in 1..100 {
    w[n] <- simulate_gaussian(0.0, 1.0);
  }

  // Compare with common two-pass algorithm. The reference ESS is
  // (sum w)^2 / (sum w^2), evaluated in log space as
  // exp(2*lse(w) - lse(2*w)).
  let y <- log_sum_exp_twopass(w);
  let ess <- exp(2*y - log_sum_exp_twopass(2*w));
  if !check_ess_logsumexp(w, ess, y) {
    exit(1);
  }

  // Check overflow: shift all log weights so that a naive exp() of them
  // would overflow.
  w <- w + 1000.0;
  y <- log_sum_exp_twopass(w);
  ess <- exp(2*y - log_sum_exp_twopass(2*w));
  if !check_ess_logsumexp(w, ess, y) {
    exit(1);
  }

  // Check underflow: exp(1e-20) + 1e-20 is approximately 1 + 2e-20, whose
  // logarithm is approximately 2e-20.
  w <- [1e-20, log(1e-20)];
  y <- 2e-20;
  ess <- 1.0;
  if !check_ess_logsumexp(w, ess, y) {
    exit(1);
  }

  // Check empty input
  x:Real[0];
  if !check_ess_logsumexp(x, 0.0, -inf) {
    exit(1);
  }

  // Special cases involving -inf, inf, and nan.
  // Columns: first log weight, second log weight, expected ESS, expected
  // log sum of weights.
  let cases <- [[-inf, -inf, 0.0, -inf],
      [-inf, nan, 0.0, -inf],
      [nan, -inf, 0.0, -inf],
      [-inf, 42.0, 1.0, 42.0],
      [nan, 42.0, 1.0, 42.0],
      [42.0, -inf, 1.0, 42.0],
      [42.0, nan, 1.0, 42.0],
      [-inf, inf, 1.0, inf],
      [nan, inf, 1.0, inf],
      [42.0, inf, 1.0, inf],
      [inf, -inf, 1.0, inf],
      [inf, nan, 1.0, inf],
      [inf, 42.0, 1.0, inf],
      [inf, inf, 1.0, inf]];
  for n in 1..length(cases) {
    w <- cases[n,1..2];
    ess <- cases[n,3];
    y <- cases[n,4];
    if !check_ess_logsumexp(w, ess, y) {
      exit(1);
    }
  }
}

/*
 * Reference log-sum-exp: exponentiate and sum a log weight vector.
 *
 * @param w Log weights.
 *
 * @return The logarithm of the sum; `-inf` when `w` is empty.
 *
 * !!! note
 *     This implementation uses the common two-pass algorithm
 *     that avoids overflow: the first pass finds the maximum with
 *     `nan_max`, the second sums `nan_exp` of the weights shifted by
 *     that maximum.
 */
function log_sum_exp_twopass(w:Real[_]) -> Real {
  // Nothing to sum: the sum of zero weights is 0, whose logarithm is -inf.
  if length(w) == 0 {
    return -inf;
  }

  // First pass: the shift applied to every log weight.
  let shift <- nan_max(w);

  // Second pass: accumulate the shifted, exponentiated weights.
  let total <- transform_reduce(w, 0.0,
      \(a:Real, b:Real) -> { return a + b; },
      \(v:Real) -> { return nan_exp(v - shift); });

  // Undo the shift in log space.
  return shift + log(total);
}

/*
 * Check output of `log_sum_exp` and `resample_reduce`, using a default
 * relative tolerance of 1e-8.
 *
 * @param w Log weights.
 * @param ess_expected Expected ESS estimate.
 * @param y_expected Expected log sum of weights.
 *
 * @return Are the results approximately equal to the expected values?
 */
function check_ess_logsumexp(w:Real[_], ess_expected:Real, y_expected:Real) -> Boolean {
  // Roughly half of the significant digits should be correct.
  return check_ess_logsumexp(w, ess_expected, y_expected, 1e-8);
}

/*
 * Check output of `log_sum_exp` and `resample_reduce`.
 *
 * Both functions are always evaluated; every mismatch is reported on
 * standard error before the combined verdict is returned.
 *
 * @param w Log weights.
 * @param ess_expected Expected ESS estimate.
 * @param y_expected Expected log sum of weights.
 * @param reltol Relative tolerance.
 *
 * @return Are the results approximately equal to the expected values?
 */
function check_ess_logsumexp(w:Real[_], ess_expected:Real, y_expected:Real, reltol:Real) -> Boolean {
  let result <- true;

  // Log sum of weights as computed by `log_sum_exp`.
  let y <- log_sum_exp(w);
  if !approx_equal(y, y_expected, reltol) {
    stderr.print("log_sum_exp(" + w + ") = " + y + " ≉ " + y_expected + " (reltol = " + reltol + ")\n");
    result <- false;
  }

  // ESS and log sum of weights as computed by `resample_reduce`; `y` is
  // deliberately overwritten with its second return value.
  ess:Real;
  (ess, y) <- resample_reduce(w);
  if !approx_equal(ess, ess_expected, reltol) || !approx_equal(y, y_expected, reltol) {
    stderr.print("resample_reduce(" + w + ") = (" + ess + ", " + y + ") ≉ (" + ess_expected + ", " + y_expected + ") (reltol = " + reltol + ")\n");
    result <- false;
  }

  return result;
}

/*
 * Check if two scalars are approximately equal.
 *
 * Two values match when they compare equal, or when their absolute
 * difference is strictly smaller than `reltol` times the larger of their
 * magnitudes.
 *
 * @param x1 First scalar.
 * @param x2 Second scalar.
 * @param reltol Relative tolerance.
 *
 * @return Are the two scalars approximately equal?
 */
function approx_equal(x1:Real, x2:Real, reltol:Real) -> Boolean {
  // Handle special cases such as `inf`, `-inf` etc. not covered below:
  // for equal infinities the difference `x1 - x2` is not a usable number.
  // NOTE(review): presumably a `nan` argument never matches, since
  // comparisons with `nan` are false under IEEE 754 - confirm for Birch.
  if x1==x2 {
    return true;
  }

  return abs(x1 - x2) < reltol*max(abs(x1), abs(x2));
}