From 53e888849e6812dc03bbfe677d8bc9833b1aa978 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 28 Sep 2021 14:01:46 +0200 Subject: [PATCH] Allow `Real` arguments in R functions and use Julia implementations only (#125) * Allow arguments of type `Float32` and `Float16` in R functions * Add tests * Bump version * Allow `Real` if no Julia fallback exists and `Int`s, `UInt`s, and `Rational`s otherwise * Do not use R implementation when Julia implementation is available * Add more tests --- Project.toml | 4 +- README.md | 1 + src/StatsFuns.jl | 1 + src/chainrules.jl | 6 +- src/distrs/beta.jl | 19 +++--- src/distrs/binom.jl | 17 ++++-- src/distrs/chisq.jl | 23 ++++--- src/distrs/fdist.jl | 20 ++++--- src/distrs/gamma.jl | 19 +++--- src/distrs/hyper.jl | 3 +- src/distrs/nbeta.jl | 3 +- src/distrs/nbinom.jl | 3 +- src/distrs/nchisq.jl | 3 +- src/distrs/nfdist.jl | 3 +- src/distrs/ntdist.jl | 3 +- src/distrs/pois.jl | 26 +++----- src/distrs/srdist.jl | 3 +- src/distrs/tdist.jl | 17 +++--- src/rmath.jl | 65 +++++++++++++------- test/chainrules.jl | 17 +++--- test/generic.jl | 13 ++-- test/rmath.jl | 140 ++++++++++++++++++++++++++++++++++++++----- 22 files changed, 279 insertions(+), 130 deletions(-) diff --git a/Project.toml b/Project.toml index 5a71cfc..32f5bd1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "StatsFuns" uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = "0.9.10" +version = "0.9.11" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -13,7 +13,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [compat] ChainRulesCore = "1" IrrationalConstants = "0.1" -LogExpFunctions = "0.3" +LogExpFunctions = "0.3.2" Reexport = "1" Rmath = "0.4, 0.5, 0.6, 0.7" SpecialFunctions = "0.8, 0.9, 0.10, 1.0" diff --git a/README.md b/README.md index 555965a..65e40d7 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,7 @@ log4π, # log(4π) # basicfuns xlogx, # x * log(x), or 0 when x is zero xlogy, # x * log(y), or 0 when x is zero +xlog1py, # x * log(1 + y) for x > 0, or 0 when x == 0 logistic, # 1 / (1 + exp(-x)) logit, # log(x / (1 - x)) log1psq, # log(1 + x^2) diff --git a/src/StatsFuns.jl b/src/StatsFuns.jl index 2ff3414..cc4e0e0 100644 --- a/src/StatsFuns.jl +++ b/src/StatsFuns.jl @@ -35,6 +35,7 @@ import ChainRulesCore @reexport using LogExpFunctions: xlogx, # x * log(x) for x > 0, or 0 when x == 0 xlogy, # x * log(y) for x > 0, or 0 when x == 0 + xlog1py, # x * log(1 + y) for x > 0, or 0 when x == 0 logistic, # 1 / (1 + exp(-x)) logit, # log(x / (1 - x)) log1psq, # log(1 + x^2) diff --git a/src/chainrules.jl b/src/chainrules.jl index f405c21..7f6040f 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -12,9 +12,9 @@ ChainRulesCore.@scalar_rule( binomlogpdf(n::Real, p::Real, k::Real), @setup(z = digamma(n - k + 1)), ( - digamma(n + 2) - z + log1p(-p) - 1 / (1 + n), + ChainRulesCore.NoTangent(), (k / p - n) / (1 - p), - z - digamma(k + 1) + logit(p), + ChainRulesCore.NoTangent(), ), ) @@ -59,7 +59,7 @@ ChainRulesCore.@scalar_rule( ChainRulesCore.@scalar_rule( poislogpdf(λ::Number, x::Number), - ((iszero(x) && iszero(λ) ? zero(x / λ) : x / λ) - 1, log(λ) - digamma(x + 1)), + ((iszero(x) && iszero(λ) ? zero(x / λ) : x / λ) - 1, ChainRulesCore.NoTangent()), ) ChainRulesCore.@scalar_rule( diff --git a/src/distrs/beta.jl b/src/distrs/beta.jl index 3d1bfae..6328c4e 100644 --- a/src/distrs/beta.jl +++ b/src/distrs/beta.jl @@ -1,8 +1,8 @@ # functions related to beta distributions -import .RFunctions: - betapdf, - betalogpdf, +# R implementations +# For pdf and logpdf we use the Julia implementation +using .RFunctions: betacdf, betaccdf, betalogcdf, @@ -12,8 +12,13 @@ import .RFunctions: betainvlogcdf, betainvlogccdf -# pdf for numbers with generic types -betapdf(α::Real, β::Real, x::Number) = x^(α - 1) * (1 - x)^(β - 1) / beta(α, β) +# Julia implementations +betapdf(α::Real, β::Real, x::Real) = exp(betalogpdf(α, β, x)) -# logpdf for numbers with generic types -betalogpdf(α::Real, β::Real, x::Number) = (α - 1) * log(x) + (β - 1) * log1p(-x) - logbeta(α, β) +betalogpdf(α::Real, β::Real, x::Real) = betalogpdf(promote(α, β, x)...) +function betalogpdf(α::T, β::T, x::T) where {T<:Real} + # we ensure that `log(x)` and `log1p(-x)` do not error + y = clamp(x, 0, 1) + val = xlogy(α - 1, y) + xlog1py(β - 1, -y) - logbeta(α, β) + return x < 0 || x > 1 ? oftype(val, -Inf) : val +end diff --git a/src/distrs/binom.jl b/src/distrs/binom.jl index e1237f1..09b1c2a 100644 --- a/src/distrs/binom.jl +++ b/src/distrs/binom.jl @@ -1,8 +1,8 @@ # functions related to binomial distribution -import .RFunctions: - binompdf, - binomlogpdf, +# R implementations +# For pdf and logpdf we use the Julia implementation +using .RFunctions: binomcdf, binomccdf, binomlogcdf, @@ -12,8 +12,13 @@ import .RFunctions: binominvlogcdf, binominvlogccdf -# pdf for numbers with generic types + +# Julia implementations binompdf(n::Real, p::Real, k::Real) = exp(binomlogpdf(n, p, k)) -# logpdf for numbers with generic types -binomlogpdf(n::Real, p::Real, k::Real) = -log1p(n) - logbeta(n - k + 1, k + 1) + k * log(p) + (n - k) * log1p(-p) +binomlogpdf(n::Real, p::Real, k::Real) = binomlogpdf(promote(n, p, k)...) +function binomlogpdf(n::T, p::T, k::T) where {T<:Real} + m = clamp(k, 0, n) + val = betalogpdf(m + 1, n - m + 1, p) - log(n + 1) + return 0 <= k <= n && isinteger(k) ? val : oftype(val, -Inf) +end diff --git a/src/distrs/chisq.jl b/src/distrs/chisq.jl index 25a069d..467fa08 100644 --- a/src/distrs/chisq.jl +++ b/src/distrs/chisq.jl @@ -1,8 +1,8 @@ # functions related to chi-square distribution -import .RFunctions: - chisqpdf, - chisqlogpdf, +# R implementations +# For pdf and logpdf we use the Julia implementation +using .RFunctions: chisqcdf, chisqccdf, chisqlogcdf, @@ -12,14 +12,11 @@ import .RFunctions: chisqinvlogcdf, chisqinvlogccdf -# pdf for numbers with generic types -function chisqpdf(k::Real, x::Number) - hk = k / 2 # half k - 1 / (2^(hk) * gamma(hk)) * x^(hk - 1) * exp(-x / 2) -end +# Julia implementations +# promotion ensures that we do forward e.g. `chisqpdf(::Int, ::Float32)` to +# `gammapdf(::Float32, ::Int, ::Float32)` but not `gammapdf(::Float64, ::Int, ::Float32)` +chisqpdf(k::Real, x::Real) = chisqpdf(promote(k, x)...) +chisqpdf(k::T, x::T) where {T<:Real} = gammapdf(k / 2, 2, x) -# logpdf for numbers with generic types -function chisqlogpdf(k::Real, x::Number) - hk = k / 2 # half k - -hk * logtwo - loggamma(hk) + (hk - 1) * log(x) - x / 2 -end +chisqlogpdf(k::Real, x::Real) = chisqlogpdf(promote(k, x)...) +chisqlogpdf(k::T, x::T) where {T<:Real} = gammalogpdf(k / 2, 2, x) diff --git a/src/distrs/fdist.jl b/src/distrs/fdist.jl index b3041e3..1f11c16 100644 --- a/src/distrs/fdist.jl +++ b/src/distrs/fdist.jl @@ -1,8 +1,8 @@ # functions related to F distribution -import .RFunctions: - fdistpdf, - fdistlogpdf, +# R implementations +# For pdf and logpdf we use the Julia implementation +using .RFunctions: fdistcdf, fdistccdf, fdistlogcdf, @@ -12,8 +12,14 @@ import .RFunctions: fdistinvlogcdf, fdistinvlogccdf -# pdf for numbers with generic types -fdistpdf(ν1::Real, ν2::Real, x::Number) = sqrt((ν1 * x)^ν1 * ν2^ν2 / (ν1 * x + ν2)^(ν1 + ν2)) / (x * beta(ν1 / 2, ν2 / 2)) +# Julia implementations +fdistpdf(ν1::Real, ν2::Real, x::Real) = exp(fdistlogpdf(ν1, ν2, x)) -# logpdf for numbers with generic types -fdistlogpdf(ν1::Real, ν2::Real, x::Number) = (ν1 * log(ν1 * x) + ν2 * log(ν2) - (ν1 + ν2) * log(ν1 * x + ν2)) / 2 - log(x) - logbeta(ν1 / 2, ν2 / 2) +fdistlogpdf(ν1::Real, ν2::Real, x::Real) = fdistlogpdf(promote(ν1, ν2, x)...) +function fdistlogpdf(ν1::T, ν2::T, x::T) where {T<:Real} + # we ensure that `log(x)` does not error if `x < 0` + ν1ν2 = ν1 / ν2 + y = max(x, 0) + val = (xlogy(ν1, ν1ν2) + xlogy(ν1 - 2, y) - xlogy(ν1 + ν2, 1 + ν1ν2 * y)) / 2 - logbeta(ν1 / 2, ν2 / 2) + return x < 0 ? oftype(val, -Inf) : val +end diff --git a/src/distrs/gamma.jl b/src/distrs/gamma.jl index 26dbef6..15b3975 100644 --- a/src/distrs/gamma.jl +++ b/src/distrs/gamma.jl @@ -1,8 +1,8 @@ # functions related to gamma distribution -import .RFunctions: - gammapdf, - gammalogpdf, +# R implementations +# For pdf and logpdf we use the Julia implementation +using .RFunctions: gammacdf, gammaccdf, gammalogcdf, @@ -12,8 +12,13 @@ import .RFunctions: gammainvlogcdf, gammainvlogccdf -# pdf for numbers with generic types -gammapdf(k::Real, θ::Real, x::Number) = 1 / (gamma(k) * θ^k) * x^(k - 1) * exp(-x / θ) +# Julia implementations +gammapdf(k::Real, θ::Real, x::Real) = exp(gammalogpdf(k, θ, x)) -# logpdf for numbers with generic types -gammalogpdf(k::Real, θ::Real, x::Number) = -loggamma(k) - k * log(θ) + (k - 1) * log(x) - x / θ +gammalogpdf(k::Real, θ::Real, x::Real) = gammalogpdf(promote(k, θ, x)...) +function gammalogpdf(k::T, θ::T, x::T) where {T<:Real} + # we ensure that `log(x)` does not error if `x < 0` + xθ = max(x, 0) / θ + val = -loggamma(k) + xlogy(k - 1, xθ) - log(θ) - xθ + return x < 0 ? oftype(val, -Inf) : val +end diff --git a/src/distrs/hyper.jl b/src/distrs/hyper.jl index 5d62bd6..6c2869f 100644 --- a/src/distrs/hyper.jl +++ b/src/distrs/hyper.jl @@ -1,6 +1,7 @@ # functions related to hyper-geometric distribution -import .RFunctions: +# R implementations +using .RFunctions: hyperpdf, hyperlogpdf, hypercdf, diff --git a/src/distrs/nbeta.jl b/src/distrs/nbeta.jl index 9007f7e..33c59cf 100644 --- a/src/distrs/nbeta.jl +++ b/src/distrs/nbeta.jl @@ -1,6 +1,7 @@ # functions related to noncentral beta distribution -import .RFunctions: +# R implementations +using .RFunctions: nbetapdf, nbetalogpdf, nbetacdf, diff --git a/src/distrs/nbinom.jl b/src/distrs/nbinom.jl index 6ffae35..5f80b1e 100644 --- a/src/distrs/nbinom.jl +++ b/src/distrs/nbinom.jl @@ -1,6 +1,7 @@ # functions related to negative binomial distribution -import .RFunctions: +# R implementations +using .RFunctions: nbinompdf, nbinomlogpdf, nbinomcdf, diff --git a/src/distrs/nchisq.jl b/src/distrs/nchisq.jl index 942b7d5..9382d04 100644 --- a/src/distrs/nchisq.jl +++ b/src/distrs/nchisq.jl @@ -1,6 +1,7 @@ # functions related to noncentral chi-square distribution -import .RFunctions: +# R implementations +using .RFunctions: nchisqpdf, nchisqlogpdf, nchisqcdf, diff --git a/src/distrs/nfdist.jl b/src/distrs/nfdist.jl index e38c2f2..3f2e735 100644 --- a/src/distrs/nfdist.jl +++ b/src/distrs/nfdist.jl @@ -1,6 +1,7 @@ # functions related to noncentral F distribution -import .RFunctions: +# R implementations +using .RFunctions: nfdistpdf, nfdistlogpdf, nfdistcdf, diff --git a/src/distrs/ntdist.jl b/src/distrs/ntdist.jl index 99cae2b..b1bbb1b 100644 --- a/src/distrs/ntdist.jl +++ b/src/distrs/ntdist.jl @@ -1,6 +1,7 @@ # functions related to noncentral T distribution -import .RFunctions: +# R implementations +using .RFunctions: ntdistpdf, ntdistlogpdf, ntdistcdf, diff --git a/src/distrs/pois.jl b/src/distrs/pois.jl index 3caffc2..ed6192b 100644 --- a/src/distrs/pois.jl +++ b/src/distrs/pois.jl @@ -1,8 +1,8 @@ # functions related to Poisson distribution -import .RFunctions: - poispdf, - poislogpdf, +# R implementations +# For pdf and logpdf we use the Julia implementation +using .RFunctions: poiscdf, poisccdf, poislogcdf, @@ -12,19 +12,11 @@ import .RFunctions: poisinvlogcdf, poisinvlogccdf -# generic versions +# Julia implementations poispdf(λ::Real, x::Real) = exp(poislogpdf(λ, x)) -poislogpdf(λ::T, x::T) where {T <: Real} = xlogy(x, λ) - λ - loggamma(x + 1) - -poislogpdf(λ::Number, x::Number) = poislogpdf(promote(float(λ), x)...) - -#= -function poislogpdf(λ::Union{Float32,Float64}, x::Union{Float64,Float32,Integer}) - if iszero(λ) - iszero(x) ? zero(λ) : oftype(λ, -Inf) - elseif iszero(x) - -λ - else - -lstirling_asym(x + 1) -=# +poislogpdf(λ::Real, x::Real) = poislogpdf(promote(λ, x)...) +function poislogpdf(λ::T, x::T) where {T <: Real} + val = xlogy(x, λ) - λ - loggamma(x + 1) + return x >= 0 && isinteger(x) ? val : oftype(val, -Inf) +end diff --git a/src/distrs/srdist.jl b/src/distrs/srdist.jl index 8d1e180..4399f45 100644 --- a/src/distrs/srdist.jl +++ b/src/distrs/srdist.jl @@ -1,6 +1,7 @@ # functions related to studentized range distribution -import .RFunctions: +# R implementations +using .RFunctions: srdistcdf, srdistccdf, srdistlogcdf, diff --git a/src/distrs/tdist.jl b/src/distrs/tdist.jl index 5dc83e9..b7aefbc 100644 --- a/src/distrs/tdist.jl +++ b/src/distrs/tdist.jl @@ -1,8 +1,8 @@ # functions related to student's T distribution -import .RFunctions: - tdistpdf, - tdistlogpdf, +# R implementations +# For pdf and logpdf we use the Julia implementation +using .RFunctions: tdistcdf, tdistccdf, tdistlogcdf, @@ -12,8 +12,11 @@ import .RFunctions: tdistinvlogcdf, tdistinvlogccdf -# pdf for numbers with generic types -tdistpdf(ν::Real, x::Number) = gamma((ν + 1) / 2) / (sqrt(ν * pi) * gamma(ν / 2)) * (1 + x^2 / ν)^(-(ν + 1) / 2) +# Julia implementations +tdistpdf(ν::Real, x::Real) = exp(tdistlogpdf(ν, x)) -# logpdf for numbers with generic types -tdistlogpdf(ν::Real, x::Number) = loggamma((ν + 1) / 2) - log(ν * pi) / 2 - loggamma(ν / 2) + (-(ν + 1) / 2) * log1p(x^2 / ν) +tdistlogpdf(ν::Real, x::Real) = tdistlogpdf(promote(ν, x)...) +function tdistlogpdf(ν::T, x::T) where {T<:Real} + νp12 = (ν + 1) / 2 + return loggamma(νp12) - (logπ + log(ν)) / 2 - loggamma(ν / 2) - νp12 * log1p(x^2 / ν) +end diff --git a/src/rmath.jl b/src/rmath.jl index 0a24ca3..595babf 100644 --- a/src/rmath.jl +++ b/src/rmath.jl @@ -57,7 +57,7 @@ function _import_rmath(rname::Symbol, jname::Symbol, pargs) rtypes = Expr(:tuple, _pts...) end - pdecls = [Expr(:(::), ps, :(Union{Float64,Int})) for ps in pargs] # [:(p1::Union{Float64, Int}), :(p2::Union{...}), ...] + pdecls = [Expr(:(::), ps, :Real) for ps in pargs] # [:(p1::Real), :(p2::Real), ...] if is_tukey # ptukey and qtukey have an extra literal 1 argument @@ -67,36 +67,56 @@ function _import_rmath(rname::Symbol, jname::Symbol, pargs) # Function implementation quote if $(!is_tukey) - $pdf($(pdecls...), x::Union{Float64,Int}) = - ccall(($dfun, libRmath), Float64, $dtypes, x, $(pargs...), 0) - - $logpdf($(pdecls...), x::Union{Float64,Int}) = - ccall(($dfun, libRmath), Float64, $dtypes, x, $(pargs...), 1) + function $pdf($(pdecls...), x::Real) + T = float(Base.promote_typeof($(pargs...), x)) + return convert(T, ccall(($dfun, libRmath), Float64, $dtypes, x, $(pargs...), 0)) + end + + function $logpdf($(pdecls...), x::Real) + T = float(Base.promote_typeof($(pargs...), x)) + return convert(T, ccall(($dfun, libRmath), Float64, $dtypes, x, $(pargs...), 1)) + end end - $cdf($(pdecls...), x::Union{Float64,Int}) = - ccall(($pfun, libRmath), Float64, $ptypes, x, $(pargs...), 1, 0) + function $cdf($(pdecls...), x::Real) + T = float(Base.promote_typeof($(pargs...), x)) + return convert(T, ccall(($pfun, libRmath), Float64, $ptypes, x, $(pargs...), 1, 0)) + end - $ccdf($(pdecls...), x::Union{Float64,Int}) = - ccall(($pfun, libRmath), Float64, $ptypes, x, $(pargs...), 0, 0) + function $ccdf($(pdecls...), x::Real) + T = float(Base.promote_typeof($(pargs...), x)) + return convert(T, ccall(($pfun, libRmath), Float64, $ptypes, x, $(pargs...), 0, 0)) + end - $logcdf($(pdecls...), x::Union{Float64,Int}) = - ccall(($pfun, libRmath), Float64, $ptypes, x, $(pargs...), 1, 1) + function $logcdf($(pdecls...), x::Real) + T = float(Base.promote_typeof($(pargs...), x)) + return convert(T, ccall(($pfun, libRmath), Float64, $ptypes, x, $(pargs...), 1, 1)) + end - $logccdf($(pdecls...), x::Union{Float64,Int}) = - ccall(($pfun, libRmath), Float64, $ptypes, x, $(pargs...), 0, 1) + function $logccdf($(pdecls...), x::Real) + T = float(Base.promote_typeof($(pargs...), x)) + return convert(T, ccall(($pfun, libRmath), Float64, $ptypes, x, $(pargs...), 0, 1)) + end - $invcdf($(pdecls...), q::Union{Float64,Int}) = - ccall(($qfun, libRmath), Float64, $qtypes, q, $(pargs...), 1, 0) + function $invcdf($(pdecls...), q::Real) + T = float(Base.promote_typeof($(pargs...), q)) + return convert(T, ccall(($qfun, libRmath), Float64, $qtypes, q, $(pargs...), 1, 0)) + end - $invccdf($(pdecls...), q::Union{Float64,Int}) = - ccall(($qfun, libRmath), Float64, $qtypes, q, $(pargs...), 0, 0) + function $invccdf($(pdecls...), q::Real) + T = float(Base.promote_typeof($(pargs...), q)) + return convert(T, ccall(($qfun, libRmath), Float64, $qtypes, q, $(pargs...), 0, 0)) + end - $invlogcdf($(pdecls...), lq::Union{Float64,Int}) = - ccall(($qfun, libRmath), Float64, $qtypes, lq, $(pargs...), 1, 1) + function $invlogcdf($(pdecls...), lq::Real) + T = float(Base.promote_typeof($(pargs...), lq)) + return convert(T, ccall(($qfun, libRmath), Float64, $qtypes, lq, $(pargs...), 1, 1)) + end - $invlogccdf($(pdecls...), lq::Union{Float64,Int}) = - ccall(($qfun, libRmath), Float64, $qtypes, lq, $(pargs...), 0, 1) + function $invlogccdf($(pdecls...), lq::Real) + T = float(Base.promote_typeof($(pargs...), lq)) + return convert(T, ccall(($qfun, libRmath), Float64, $qtypes, lq, $(pargs...), 0, 1)) + end if $has_rand $rand($(pdecls...)) = @@ -109,7 +129,6 @@ macro import_rmath(rname, jname, pargs...) esc(_import_rmath(rname, jname, pargs)) end - ### Import specific functions @import_rmath beta beta α β diff --git a/test/chainrules.jl b/test/chainrules.jl index e469ebf..d59b5ea 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -32,25 +32,22 @@ using Random test_frule(tdistlogpdf, x, y) test_rrule(tdistlogpdf, x, y) - # use `BigFloat` to avoid Rmath implementation in finite differencing check - # (returns `NaN` for non-integer values) - n = rand(1:100) - x = BigFloat(n) - y = big(logistic(randn())) - z = BigFloat(rand(1:n)) + x = rand(1:100) + y = logistic(randn()) + z = rand(1:x) test_frule(binomlogpdf, x, y, z) test_rrule(binomlogpdf, x, y, z) - x = big(exp(randn())) - y = BigFloat(rand(1:100)) + x = exp(randn()) + y = rand(1:100) test_frule(poislogpdf, x, y) test_rrule(poislogpdf, x, y) # test special case λ = 0 - _, pb = rrule(StatsFuns.poislogpdf, 0.0, 0.0) + _, pb = rrule(poislogpdf, 0.0, 0) _, x̄1, _ = pb(1) @test x̄1 == -1 - _, pb = rrule(StatsFuns.poislogpdf, 0.0, 1.0) + _, pb = rrule(poislogpdf, 0.0, 1) _, x̄1, _ = pb(1) @test x̄1 == Inf end diff --git a/test/generic.jl b/test/generic.jl index 1b4bf0e..473d6cd 100644 --- a/test/generic.jl +++ b/test/generic.jl @@ -1,9 +1,10 @@ using StatsFuns +using StatsFuns: RFunctions using ForwardDiff: Dual -function check_rmath(fname, statsfun, params, aname, a, isprob, rtol) - v = statsfun(params..., a) - rv = statsfun(params..., Dual(a)).value +function check_rmath(fname, statsfun, rmathfun, params, aname, a, isprob, rtol) + v = @inferred(rmathfun(params..., a)) + rv = @inferred(statsfun(params..., Dual(a))).value if isprob rd = abs(v / rv - 1.0) if rd > rtol @@ -27,10 +28,12 @@ function genericcomp(basename, params, X::AbstractArray, rtol=100eps(float(one(e logpdf = string(basename, "logpdf") stats_pdf = eval(Symbol(pdf)) stats_logpdf = eval(Symbol(logpdf)) + rmath_pdf = eval(Meta.parse(string("RFunctions.", pdf))) + rmath_logpdf = eval(Meta.parse(string("RFunctions.", logpdf))) for i = 1:length(X) x = X[i] - check_rmath(pdf, stats_pdf, params, "x", x, true, rtol) - check_rmath(logpdf, stats_logpdf, params, "x", x, false, rtol) + check_rmath(pdf, stats_pdf, rmath_pdf, params, "x", x, true, rtol) + check_rmath(logpdf, stats_logpdf, rmath_logpdf, params, "x", x, false, rtol) end end diff --git a/test/rmath.jl b/test/rmath.jl index 138cce8..f3c5829 100644 --- a/test/rmath.jl +++ b/test/rmath.jl @@ -1,9 +1,12 @@ -using StatsFuns, Test -import StatsFuns.RFunctions +using StatsFuns +using StatsFuns: RFunctions +using Test function check_rmath(fname, statsfun, rmathfun, params, aname, a, isprob, rtol) - v = statsfun(params..., a) - rv = rmathfun(params..., a) + v = @inferred(statsfun(params..., a)) + rv = @inferred(rmathfun(params..., a)) + @test v isa float(Base.promote_typeof(params..., a)) + @test rv isa float(Base.promote_typeof(params..., a)) if isprob @test v ≈ rv rtol=rtol nans=true else @@ -116,12 +119,35 @@ end @testset "RMath" begin rmathcomp_tests("beta", [ - ((1.0, 1.0), 0.01:0.01:0.99), - ((2.0, 3.0), 0.01:0.01:0.99), - ((10.0, 2.0), 0.01:0.01:0.99), - ((10, 2), 0.01:0.01:0.99), + ((0.1, 1.0), 0.0:0.01:1.0), + ((1.0, 1.0), 0.0:0.01:1.0), + ((2.0, 3.0), 0.0:0.01:1.0), + ((10.0, 2.0), 0.0:0.01:1.0), + ((10, 2), 0.0:0.01:1.0), + ((1f0, 1f0), 0f0:0.01f0:1f0), + ((1.0, 1.0), 0f0:0.01f0:1f0), + ((Float16(1), Float16(1)), Float16(0):Float16(0.01):Float16(1)), + ((1f0, 1f0), Float16(0):Float16(0.01):Float16(1)), + ((10, 2), [0, 1]), + ((10, 2), 0//1:1//100:1//1), ]) + # We test the following extreme parameters separately since + # a slightly larger tolerance is needed. + # + # For `betapdf(1000, 2, 0.58)`: + # StatsFuns: 1.9419987107407202e-231 + # Rmath: 1.941998710740941e-231 + # Mathematica: 1.941998710742487e-231 + # For `betapdf(1000, 2, 0.68)`: + # StatsFuns: 1.5205049885199752e-162 + # Rmath: 1.5205049885200616e-162 + # Mathematica: 1.520504988521358e-162 + @testset "Beta(1000, 2)" begin + rmathcomp("beta", (1000, 2), setdiff(0.0:0.01:1.0, (0.58, 0.68))) + rmathcomp("beta", (1000, 2), [0.58, 0.68], 1e-12) + end + rmathcomp_tests("binom", [ ((1, 0.5), 0.0:1.0), ((1, 0.7), 0.0:1.0), @@ -129,6 +155,11 @@ end ((20, 0.1), 0.0:20.0), ((20, 0.9), 0.0:20.0), ((20, 0.9), 0:20), + ((1, 0.5f0), 0f0:1f0), + ((1, 0.5), 0f0:1f0), + ((1, Float16(0.5)), Float16(0):Float16(1)), + ((1, 0.5f0), Float16(0):Float16(1)), + ((10, 1//2), 0//1:10//1), ]) rmathcomp_tests("chisq", [ @@ -136,6 +167,9 @@ end ((4,), 0.0:0.1:8.0), ((9,), 0.0:0.1:8.0), ((9,), 0:8), + ((1,), 0f0:0.1f0:8f0), + ((1,), Float16(0):Float16(0.1):Float16(8)), + ((9,), 0//1:8//1), ]) rmathcomp_tests("fdist", [ @@ -145,20 +179,31 @@ end ((10, 1), (0.0:0.1:5.0)), ((10, 3), (0.0:0.1:5.0)), ((10, 3), (0:5)), + ((1, 1), (0f0:0.1f0:5f0)), + ((1, 1), (Float16(0):Float16(0.1):Float16(5))), + ((10, 3), 0//1:5//1), ]) rmathcomp_tests("gamma", [ - ((1.0, 1.0), (0.05:0.05:12.0)), - ((0.5, 1.0), (0.05:0.05:12.0)), - ((3.0, 1.0), (0.05:0.05:12.0)), - ((9.0, 1.0), (0.05:0.05:12.0)), - ((2.0, 3.0), (0.05:0.05:12.0)), - ((2, 3), (1:12)), + ((1.0, 1.0), (0.0:0.05:12.0)), + ((0.5, 1.0), (0.0:0.05:12.0)), + ((3.0, 1.0), (0.0:0.05:12.0)), + ((9.0, 1.0), (0.0:0.05:12.0)), + ((2.0, 3.0), (0.0:0.05:12.0)), + ((2, 3), (0:12)), + ((1f0, 1f0), (0f0:0.05f0:12f0)), + ((1.0, 1.0), (0f0:0.05f0:12f0)), + ((Float16(1), Float16(1)), (Float16(0):Float16(0.05):Float16(12))), + ((1f0, 1f0), (Float16(0):Float16(0.05):Float16(12))), + ((2, 3), (0//1:12//1)), ]) rmathcomp_tests("hyper", [ ((2, 3, 4), 0.0:4.0), - ((2, 3, 4), 0:4) + ((2, 3, 4), 0:4), + ((2, 3, 4), 0f0:4f0), + ((2, 3, 4), Float16(0):Float16(4)), + ((2, 3, 4), 0//1:4//1), ]) rmathcomp_tests("nbeta", [ @@ -167,6 +212,11 @@ end ((1.0, 1.0, 2.0), 0.01:0.01:0.99), ((3.0, 4.0, 2.0), 0.01:0.01:0.99), ((3, 4, 2), 0.01:0.01:0.99), + ((1f0, 1f0, 0f0), 0.01f0:0.01f0:0.99f0), + ((1.0, 1.0, 0.0), 0.01f0:0.01f0:0.99f0), + ((Float16(1), Float16(1), Float16(0)), Float16(0.01):Float16(0.01):Float16(0.99)), + ((1f0, 1f0, 0f0), Float16(0.01):Float16(0.01):Float16(0.99)), + ((3, 4, 2), 1//100:1//100:99//100), ]) rmathcomp_tests("nbinom", [ @@ -174,7 +224,11 @@ end ((3, 0.5), 0.0:20.0), ((3, 0.2), 0.0:20.0), ((3, 0.8), 0.0:20.0), + ((1, 0.5f0), 0f0:20f0), + ((1, 0.5), 0f0:20f0), + ((1, Float16(0.5)), Float16(0):Float16(20)), ((3, 0.8), 0:20), + ((3, 1//2), 0//1:20//1), ]) rmathcomp_tests("nchisq", [ @@ -183,6 +237,9 @@ end ((4, 1), 0.0:0.2:8.0), ((4, 3), 0.0:0.2:8.0), ((4, 3), 0:8), + ((2, 1), 0f0:0.2f0:8f0), + ((2, 1), Float16(0):Float16(0.2):Float16(8)), + ((2, 1), 0//1:1//5:8//1), ]) rmathcomp_tests("nfdist", [ @@ -190,6 +247,11 @@ end ((1.0, 1.0, 2.0), 0.1:0.1:10.0), ((2.0, 3.0, 1.0), 0.1:0.1:10.0), ((2, 3, 1), 1:10), + ((1f0, 1f0, 0f0), 0.1f0:0.1f0:10f0), + ((1.0, 1.0, 0.0), 0.1f0:0.1f0:10f0), + ((Float16(1), Float16(1), Float16(0)), Float16(0.1):Float16(0.1):Float16(10)), + ((1f0, 1f0, 0f0), Float16(0.1):Float16(0.1):Float16(10)), + ((2, 3, 1), 1//1:10//1), ]) rmathcomp_tests("norm", [ @@ -200,6 +262,11 @@ end ((0, 1), -3.0:0.01:3.0), ((0, 1), -3:3), ((0.0, 0.0), -6.0:0.1:6.0), + ((0f0, 1f0), -6f0:0.01f0:6f0), + ((0.0, 1.0), -6f0:0.01f0:6f0), + # Fail since `SpecialFunctions.erfcx` is not implemented for `Float16` + #((Float16(0), Float16(1)), -Float16(6):Float16(0.01):Float16(6)), + #((0f0, 1f0), -Float16(6):Float16(0.01):Float16(6)), ]) rmathcomp_tests("ntdist", [ @@ -208,6 +275,9 @@ end ((2, 1), -4.0:0.1:10.0), ((2, 4), -4.0:0.1:10.0), ((2, 4), -4:10), + ((0, 1), -4f0:0.1f0:10f0), + ((0, 1), -Float16(4):Float16(0.1):Float16(10)), + ((0, 1), -4//1:1//10:10//1), ]) rmathcomp_tests("pois", [ @@ -216,6 +286,11 @@ end ((2.0,), 0.0:20.0), ((10.0,), 0.0:20.0), ((10,), 0:20), + ((0.5f0,), 0f0:20f0), + ((0.5,), 0f0:20f0), + ((Float16(0.5),), Float16(0):Float16(20)), + ((0.5f0,), Float16(0):Float16(20)), + ((1//2,), 0//1:20//1), ]) rmathcomp_tests("tdist", [ @@ -223,6 +298,9 @@ end ((2,), -5.0:0.1:5.0), ((5,), -5.0:0.1:5.0), ((5,), -5:5), + ((1,), -5f0:0.1f0:5f0), + ((1,), -Float16(5):Float16(0.1):Float16(5)), + ((1,), -5//1:5//1), ]) rmathcomp_tests("srdist", [ @@ -230,7 +308,10 @@ end ((2,2), (0.0:0.2:5.0)), ((5,3), (0.0:0.2:5.0)), ((10,2), (0.0:0.2:5.0)), - ((10,5), (0.0:0.2:5.0)) + ((10,5), (0.0:0.2:5.0)), + ((1,2), (0f0:0.2f0:5f0)), + ((1,2), (Float16(0):Float16(0.2):Float16(5))), + ((1,2), (0//1:1//5:5//1)), ]) # Note: Convergence fails in srdist with cdf values below 0.16 with df = 10, k = 5. @@ -240,4 +321,31 @@ end rx = srdistinvcdf(10, 5, q) rtol = 100eps(1.0) @test_broken x ≈ rx atol=rtol rtol=rtol nans=true + + # Test values outside of the support + rmathcomp_tests("beta", [ + ((1.0, 1.0), [-10.0, -6.3, 2.1, 23.5]), + ((1//1, 1//1), [-10//1, -63//10, 21//10, 47//2]), + ((1, 1), [-10, -6, 2, 24]), + ]) + rmathcomp_tests("binom", [ + ((5, 0.5), [-8, -2.3, 1.2, 5.4, 11.9]), + ((5, 1//2), [-8, -23//10, 6//5, 27//5, 119//10]), + ((5, 1//2), [-8, -2, 6, 12]), + ]) + rmathcomp_tests("fdist", [ + ((1.0, 1.0), [-10.0, -6.3]), + ((1//1, 1//1), [-10//1, -63//10]), + ((1, 1), [-10, -6]), + ]) + rmathcomp_tests("gamma", [ + ((1.0, 1.0), [-10.0, -6.3]), + ((1//1, 1//1), [-10//1, -63//10]), + ((1, 1), [-10, -6]), + ]) + rmathcomp_tests("pois", [ + ((0.5,), [-10, -2.5, 1.3, 8.7]), + ((1//2,), [-10, -5//2, 13//10, 87//10]), + ((1,), [-10, -3]), + ]) end