diff --git a/tests/Test/src/basic/test_basic_logsumexp.birch b/tests/Test/src/basic/test_basic_log_sum_exp.birch similarity index 83% rename from tests/Test/src/basic/test_basic_logsumexp.birch rename to tests/Test/src/basic/test_basic_log_sum_exp.birch index 59d6bddd0..c37659590 100644 --- a/tests/Test/src/basic/test_basic_logsumexp.birch +++ b/tests/Test/src/basic/test_basic_log_sum_exp.birch @@ -3,7 +3,7 @@ * Test log-sum-exp implementations in `log_sum_exp` and * `resample_reduce`. */ -program test_basic_logsumexp() { +program test_basic_log_sum_exp() { /* generate random weights */ w:Real[100]; for n in 1..100 { @@ -13,7 +13,7 @@ program test_basic_logsumexp() { /* compare with common two-pass algorithm */ let y <- log_sum_exp_twopass(w); let ess <- exp(2*y - log_sum_exp_twopass(2*w)); - if !check_ess_logsumexp(w, ess, y) { + if !check_ess_log_sum_exp(w, ess, y) { exit(1); } @@ -21,7 +21,7 @@ program test_basic_logsumexp() { w <- w + 1000.0; y <- log_sum_exp_twopass(w); ess <- exp(2*y - log_sum_exp_twopass(2*w)); - if !check_ess_logsumexp(w, ess, y) { + if !check_ess_log_sum_exp(w, ess, y) { exit(1); } @@ -29,13 +29,13 @@ program test_basic_logsumexp() { w <- [1e-20, log(1e-20)]; y <- 2e-20; ess <- 1.0; - if !check_ess_logsumexp(w, ess, y) { + if !check_ess_log_sum_exp(w, ess, y) { exit(1); } /* check empty input */ x:Real[0]; - if !check_ess_logsumexp(x, 0.0, -inf) { + if !check_ess_log_sum_exp(x, 0.0, -inf) { exit(1); } @@ -58,7 +58,7 @@ program test_basic_logsumexp() { w <- cases[n,1..2]; ess <- cases[n,3]; y <- cases[n,4]; - if !check_ess_logsumexp(w, ess, y) { + if !check_ess_log_sum_exp(w, ess, y) { exit(1); } } @@ -95,9 +95,9 @@ function log_sum_exp_twopass(w:Real[_]) -> Real { * * @return Are the results approximately equal to the expected values? */ -function check_ess_logsumexp(w:Real[_], ess_expected:Real, y_expected:Real) -> Boolean { +function check_ess_log_sum_exp(w:Real[_], ess_expected:Real, y_expected:Real) -> Boolean { /* roughly half of the significant digits should be correct */ - return check_ess_logsumexp(w, ess_expected, y_expected, 1e-8); + return check_ess_log_sum_exp(w, ess_expected, y_expected, 1e-8); } /* @@ -110,7 +110,7 @@ function check_ess_logsumexp(w:Real[_], ess_expected:Real, y_expected:Real) -> B * * @return Are the results approximately equal to the expected values? */ -function check_ess_logsumexp(w:Real[_], ess_expected:Real, y_expected:Real, reltol:Real) -> Boolean { +function check_ess_log_sum_exp(w:Real[_], ess_expected:Real, y_expected:Real, reltol:Real) -> Boolean { let result <- true; let y <- log_sum_exp(w); @@ -139,10 +139,6 @@ function check_ess_logsumexp(w:Real[_], ess_expected:Real, y_expected:Real, relt * @return Are the two scalars approximately equal? */ function approx_equal(x1:Real, x2:Real, reltol:Real) -> Boolean { - // Handle special cases such as `inf`, `-inf` etc. not covered below. - if x1 == x2 { - return true; - } - - return abs(x1 - x2) < reltol*max(abs(x1), abs(x2)); + /* check for equality handles special cases such as `inf`, `-inf` etc. */ + return x1 == x2 || abs(x1 - x2) < reltol*max(abs(x1), abs(x2)); } \ No newline at end of file