diff --git a/src/basicfuns.jl b/src/basicfuns.jl index 576ccea..e741abe 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -13,7 +13,7 @@ julia> StatsFuns.xlogx(0) """ function xlogx(x) result = x * log(x) - ifelse(x > zero(x), result, zero(result)) + ifelse(iszero(x), zero(result), result) end """ @@ -23,7 +23,7 @@ Return `x * log(y)` for `y > 0` with correct limit at `x = 0`. """ function xlogy(x, y) result = x * log(y) - ifelse(x > zero(x), result, zero(result)) + ifelse(iszero(x), zero(result), result) end # The following bounds are precomputed versions of the following abstract diff --git a/test/basicfuns.jl b/test/basicfuns.jl index d4f6b9c..bfbd639 100644 --- a/test/basicfuns.jl +++ b/test/basicfuns.jl @@ -3,9 +3,11 @@ using StatsFuns, Test @testset "xlogx & xlogy" begin @test iszero(xlogx(0)) @test xlogx(2) ≈ 2.0 * log(2.0) + @test_throws DomainError xlogx(-1) @test iszero(xlogy(0, 1)) @test xlogy(2, 3) ≈ 2.0 * log(3.0) + @test_throws DomainError xlogy(1, -1) end @testset "logistic & logit" begin @@ -94,7 +96,7 @@ end @test logsumexp(arguments) ≡ result end end - + @test isnan(logsubexp(Inf, Inf)) @test isnan(logsubexp(-Inf, -Inf)) @test logsubexp(Inf, 9.0) ≡ Inf