Skip to content

Commit

Permalink
Rename to log_sum_exp and simplify approx_equal
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Aug 16, 2022
1 parent 17f4b82 commit e37aec5
Showing 1 changed file with 11 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* Test log-sum-exp implementations in `log_sum_exp` and
* `resample_reduce`.
*/
program test_basic_logsumexp() {
program test_basic_log_sum_exp() {
/* generate random weights */
w:Real[100];
for n in 1..100 {
Expand All @@ -13,29 +13,29 @@ program test_basic_logsumexp() {
/* compare with common two-pass algorithm */
let y <- log_sum_exp_twopass(w);
let ess <- exp(2*y - log_sum_exp_twopass(2*w));
if !check_ess_logsumexp(w, ess, y) {
if !check_ess_log_sum_exp(w, ess, y) {
exit(1);
}

/* check overflow */
w <- w + 1000.0;
y <- log_sum_exp_twopass(w);
ess <- exp(2*y - log_sum_exp_twopass(2*w));
if !check_ess_logsumexp(w, ess, y) {
if !check_ess_log_sum_exp(w, ess, y) {
exit(1);
}

/* check underflow */
w <- [1e-20, log(1e-20)];
y <- 2e-20;
ess <- 1.0;
if !check_ess_logsumexp(w, ess, y) {
if !check_ess_log_sum_exp(w, ess, y) {
exit(1);
}

/* check empty input */
x:Real[0];
if !check_ess_logsumexp(x, 0.0, -inf) {
if !check_ess_log_sum_exp(x, 0.0, -inf) {
exit(1);
}

Expand All @@ -58,7 +58,7 @@ program test_basic_logsumexp() {
w <- cases[n,1..2];
ess <- cases[n,3];
y <- cases[n,4];
if !check_ess_logsumexp(w, ess, y) {
if !check_ess_log_sum_exp(w, ess, y) {
exit(1);
}
}
Expand Down Expand Up @@ -95,9 +95,9 @@ function log_sum_exp_twopass(w:Real[_]) -> Real {
*
* @return Are the results approximately equal to the expected values?
*/
function check_ess_logsumexp(w:Real[_], ess_expected:Real, y_expected:Real) -> Boolean {
function check_ess_log_sum_exp(w:Real[_], ess_expected:Real, y_expected:Real) -> Boolean {
/* roughly half of the significant digits should be correct */
return check_ess_logsumexp(w, ess_expected, y_expected, 1e-8);
return check_ess_log_sum_exp(w, ess_expected, y_expected, 1e-8);
}

/*
Expand All @@ -110,7 +110,7 @@ function check_ess_logsumexp(w:Real[_], ess_expected:Real, y_expected:Real) -> B
*
* @return Are the results approximately equal to the expected values?
*/
function check_ess_logsumexp(w:Real[_], ess_expected:Real, y_expected:Real, reltol:Real) -> Boolean {
function check_ess_log_sum_exp(w:Real[_], ess_expected:Real, y_expected:Real, reltol:Real) -> Boolean {
let result <- true;

let y <- log_sum_exp(w);
Expand Down Expand Up @@ -139,10 +139,6 @@ function check_ess_logsumexp(w:Real[_], ess_expected:Real, y_expected:Real, relt
* @return Are the two scalars approximately equal?
*/
function approx_equal(x1:Real, x2:Real, reltol:Real) -> Boolean {
// Handle special cases such as `inf`, `-inf` etc. not covered below.
if x1 == x2 {
return true;
}

return abs(x1 - x2) < reltol*max(abs(x1), abs(x2));
/* check for equality handles special cases such as `inf`, `-inf` etc. */
return x1 == x2 || abs(x1 - x2) < reltol*max(abs(x1), abs(x2));
}

0 comments on commit e37aec5

Please sign in to comment.