Skip to content

Commit

Permalink
Allow Real arguments in R functions and use Julia implementations o…
Browse files Browse the repository at this point in the history
…nly (#125)

* Allow arguments of type `Float32` and `Float16` in R functions

* Add tests

* Bump version

* Allow `Real` if no Julia fallback exists and `Int`s, `UInt`s, and `Rational`s otherwise

* Do not use R implementation when Julia implementation is available

* Add more tests
  • Loading branch information
devmotion authored Sep 28, 2021
1 parent fe65c00 commit 53e8888
Show file tree
Hide file tree
Showing 22 changed files with 279 additions and 130 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "StatsFuns"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "0.9.10"
version = "0.9.11"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -13,7 +13,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
[compat]
ChainRulesCore = "1"
IrrationalConstants = "0.1"
LogExpFunctions = "0.3"
LogExpFunctions = "0.3.2"
Reexport = "1"
Rmath = "0.4, 0.5, 0.6, 0.7"
SpecialFunctions = "0.8, 0.9, 0.10, 1.0"
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ log4π, # log(4π)
# basicfuns
xlogx, # x * log(x), or 0 when x is zero
xlogy, # x * log(y), or 0 when x is zero
xlog1py, # x * log(1 + y) for x > 0, or 0 when x == 0
logistic, # 1 / (1 + exp(-x))
logit, # log(x / (1 - x))
log1psq, # log(1 + x^2)
Expand Down
1 change: 1 addition & 0 deletions src/StatsFuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import ChainRulesCore
@reexport using LogExpFunctions:
xlogx, # x * log(x) for x > 0, or 0 when x == 0
xlogy, # x * log(y) for x > 0, or 0 when x == 0
xlog1py, # x * log(1 + y) for x > 0, or 0 when x == 0
logistic, # 1 / (1 + exp(-x))
logit, # log(x / (1 - x))
log1psq, # log(1 + x^2)
Expand Down
6 changes: 3 additions & 3 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ ChainRulesCore.@scalar_rule(
binomlogpdf(n::Real, p::Real, k::Real),
@setup(z = digamma(n - k + 1)),
(
digamma(n + 2) - z + log1p(-p) - 1 / (1 + n),
ChainRulesCore.NoTangent(),
(k / p - n) / (1 - p),
z - digamma(k + 1) + logit(p),
ChainRulesCore.NoTangent(),
),
)

Expand Down Expand Up @@ -59,7 +59,7 @@ ChainRulesCore.@scalar_rule(

ChainRulesCore.@scalar_rule(
poislogpdf::Number, x::Number),
((iszero(x) && iszero(λ) ? zero(x / λ) : x / λ) - 1, log(λ) - digamma(x + 1)),
((iszero(x) && iszero(λ) ? zero(x / λ) : x / λ) - 1, ChainRulesCore.NoTangent()),
)

ChainRulesCore.@scalar_rule(
Expand Down
19 changes: 12 additions & 7 deletions src/distrs/beta.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# functions related to beta distributions

import .RFunctions:
betapdf,
betalogpdf,
# R implementations
# For pdf and logpdf we use the Julia implementation
using .RFunctions:
betacdf,
betaccdf,
betalogcdf,
Expand All @@ -12,8 +12,13 @@ import .RFunctions:
betainvlogcdf,
betainvlogccdf

# pdf for numbers with generic types
betapdf::Real, β::Real, x::Number) = x^- 1) * (1 - x)^- 1) / beta(α, β)
# Julia implementations
betapdf::Real, β::Real, x::Real) = exp(betalogpdf(α, β, x))

# logpdf for numbers with generic types
betalogpdf::Real, β::Real, x::Number) =- 1) * log(x) +- 1) * log1p(-x) - logbeta(α, β)
betalogpdf::Real, β::Real, x::Real) = betalogpdf(promote(α, β, x)...)
function betalogpdf::T, β::T, x::T) where {T<:Real}
# we ensure that `log(x)` and `log1p(-x)` do not error
y = clamp(x, 0, 1)
val = xlogy- 1, y) + xlog1py- 1, -y) - logbeta(α, β)
return x < 0 || x > 1 ? oftype(val, -Inf) : val
end
17 changes: 11 additions & 6 deletions src/distrs/binom.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# functions related to binomial distribution

import .RFunctions:
binompdf,
binomlogpdf,
# R implementations
# For pdf and logpdf we use the Julia implementation
using .RFunctions:
binomcdf,
binomccdf,
binomlogcdf,
Expand All @@ -12,8 +12,13 @@ import .RFunctions:
binominvlogcdf,
binominvlogccdf

# pdf for numbers with generic types

# Julia implementations
binompdf(n::Real, p::Real, k::Real) = exp(binomlogpdf(n, p, k))

# logpdf for numbers with generic types
binomlogpdf(n::Real, p::Real, k::Real) = -log1p(n) - logbeta(n - k + 1, k + 1) + k * log(p) + (n - k) * log1p(-p)
binomlogpdf(n::Real, p::Real, k::Real) = binomlogpdf(promote(n, p, k)...)
function binomlogpdf(n::T, p::T, k::T) where {T<:Real}
m = clamp(k, 0, n)
val = betalogpdf(m + 1, n - m + 1, p) - log(n + 1)
return 0 <= k <= n && isinteger(k) ? val : oftype(val, -Inf)
end
23 changes: 10 additions & 13 deletions src/distrs/chisq.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# functions related to chi-square distribution

import .RFunctions:
chisqpdf,
chisqlogpdf,
# R implementations
# For pdf and logpdf we use the Julia implementation
using .RFunctions:
chisqcdf,
chisqccdf,
chisqlogcdf,
Expand All @@ -12,14 +12,11 @@ import .RFunctions:
chisqinvlogcdf,
chisqinvlogccdf

# pdf for numbers with generic types
function chisqpdf(k::Real, x::Number)
hk = k / 2 # half k
1 / (2^(hk) * gamma(hk)) * x^(hk - 1) * exp(-x / 2)
end
# Julia implementations
# promotion ensures that we do forward e.g. `chisqpdf(::Int, ::Float32)` to
# `gammapdf(::Float32, ::Int, ::Float32)` but not `gammapdf(::Float64, ::Int, ::Float32)`
chisqpdf(k::Real, x::Real) = chisqpdf(promote(k, x)...)
chisqpdf(k::T, x::T) where {T<:Real} = gammapdf(k / 2, 2, x)

# logpdf for numbers with generic types
function chisqlogpdf(k::Real, x::Number)
hk = k / 2 # half k
-hk * logtwo - loggamma(hk) + (hk - 1) * log(x) - x / 2
end
chisqlogpdf(k::Real, x::Real) = chisqlogpdf(promote(k, x)...)
chisqlogpdf(k::T, x::T) where {T<:Real} = gammalogpdf(k / 2, 2, x)
20 changes: 13 additions & 7 deletions src/distrs/fdist.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# functions related to F distribution

import .RFunctions:
fdistpdf,
fdistlogpdf,
# R implementations
# For pdf and logpdf we use the Julia implementation
using .RFunctions:
fdistcdf,
fdistccdf,
fdistlogcdf,
Expand All @@ -12,8 +12,14 @@ import .RFunctions:
fdistinvlogcdf,
fdistinvlogccdf

# pdf for numbers with generic types
fdistpdf(ν1::Real, ν2::Real, x::Number) = sqrt((ν1 * x)^ν1 * ν2^ν2 / (ν1 * x + ν2)^(ν1 + ν2)) / (x * beta(ν1 / 2, ν2 / 2))
# Julia implementations
fdistpdf(ν1::Real, ν2::Real, x::Real) = exp(fdistlogpdf(ν1, ν2, x))

# logpdf for numbers with generic types
fdistlogpdf(ν1::Real, ν2::Real, x::Number) = (ν1 * log(ν1 * x) + ν2 * log(ν2) - (ν1 + ν2) * log(ν1 * x + ν2)) / 2 - log(x) - logbeta(ν1 / 2, ν2 / 2)
fdistlogpdf(ν1::Real, ν2::Real, x::Real) = fdistlogpdf(promote(ν1, ν2, x)...)
function fdistlogpdf(ν1::T, ν2::T, x::T) where {T<:Real}
# we ensure that `log(x)` does not error if `x < 0`
ν1ν2 = ν1 / ν2
y = max(x, 0)
val = (xlogy(ν1, ν1ν2) + xlogy(ν1 - 2, y) - xlogy(ν1 + ν2, 1 + ν1ν2 * y)) / 2 - logbeta(ν1 / 2, ν2 / 2)
return x < 0 ? oftype(val, -Inf) : val
end
19 changes: 12 additions & 7 deletions src/distrs/gamma.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# functions related to gamma distribution

import .RFunctions:
gammapdf,
gammalogpdf,
# R implementations
# For pdf and logpdf we use the Julia implementation
using .RFunctions:
gammacdf,
gammaccdf,
gammalogcdf,
Expand All @@ -12,8 +12,13 @@ import .RFunctions:
gammainvlogcdf,
gammainvlogccdf

# pdf for numbers with generic types
gammapdf(k::Real, θ::Real, x::Number) = 1 / (gamma(k) * θ^k) * x^(k - 1) * exp(-x / θ)
# Julia implementations
gammapdf(k::Real, θ::Real, x::Real) = exp(gammalogpdf(k, θ, x))

# logpdf for numbers with generic types
gammalogpdf(k::Real, θ::Real, x::Number) = -loggamma(k) - k * log(θ) + (k - 1) * log(x) - x / θ
gammalogpdf(k::Real, θ::Real, x::Real) = gammalogpdf(promote(k, θ, x)...)
function gammalogpdf(k::T, θ::T, x::T) where {T<:Real}
# we ensure that `log(x)` does not error if `x < 0`
= max(x, 0) / θ
val = -loggamma(k) + xlogy(k - 1, xθ) - log(θ) -
return x < 0 ? oftype(val, -Inf) : val
end
3 changes: 2 additions & 1 deletion src/distrs/hyper.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# functions related to hyper-geometric distribution

import .RFunctions:
# R implementations
using .RFunctions:
hyperpdf,
hyperlogpdf,
hypercdf,
Expand Down
3 changes: 2 additions & 1 deletion src/distrs/nbeta.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# functions related to noncentral beta distribution

import .RFunctions:
# R implementations
using .RFunctions:
nbetapdf,
nbetalogpdf,
nbetacdf,
Expand Down
3 changes: 2 additions & 1 deletion src/distrs/nbinom.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# functions related to negative binomial distribution

import .RFunctions:
# R implementations
using .RFunctions:
nbinompdf,
nbinomlogpdf,
nbinomcdf,
Expand Down
3 changes: 2 additions & 1 deletion src/distrs/nchisq.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# functions related to noncentral chi-square distribution

import .RFunctions:
# R implementations
using .RFunctions:
nchisqpdf,
nchisqlogpdf,
nchisqcdf,
Expand Down
3 changes: 2 additions & 1 deletion src/distrs/nfdist.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# functions related to noncentral F distribution

import .RFunctions:
# R implementations
using .RFunctions:
nfdistpdf,
nfdistlogpdf,
nfdistcdf,
Expand Down
3 changes: 2 additions & 1 deletion src/distrs/ntdist.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# functions related to noncentral T distribution

import .RFunctions:
# R implementations
using .RFunctions:
ntdistpdf,
ntdistlogpdf,
ntdistcdf,
Expand Down
26 changes: 9 additions & 17 deletions src/distrs/pois.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# functions related to Poisson distribution

import .RFunctions:
poispdf,
poislogpdf,
# R implementations
# For pdf and logpdf we use the Julia implementation
using .RFunctions:
poiscdf,
poisccdf,
poislogcdf,
Expand All @@ -12,19 +12,11 @@ import .RFunctions:
poisinvlogcdf,
poisinvlogccdf

# generic versions
# Julia implementations
poispdf::Real, x::Real) = exp(poislogpdf(λ, x))

poislogpdf::T, x::T) where {T <: Real} = xlogy(x, λ) - λ - loggamma(x + 1)

poislogpdf::Number, x::Number) = poislogpdf(promote(float(λ), x)...)

#=
function poislogpdf(λ::Union{Float32,Float64}, x::Union{Float64,Float32,Integer})
if iszero(λ)
iszero(x) ? zero(λ) : oftype(λ, -Inf)
elseif iszero(x)
else
-lstirling_asym(x + 1)
=#
poislogpdf::Real, x::Real) = poislogpdf(promote(λ, x)...)
function poislogpdf::T, x::T) where {T <: Real}
val = xlogy(x, λ) - λ - loggamma(x + 1)
return x >= 0 && isinteger(x) ? val : oftype(val, -Inf)
end
3 changes: 2 additions & 1 deletion src/distrs/srdist.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# functions related to studentized range distribution

import .RFunctions:
# R implementations
using .RFunctions:
srdistcdf,
srdistccdf,
srdistlogcdf,
Expand Down
17 changes: 10 additions & 7 deletions src/distrs/tdist.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# functions related to student's T distribution

import .RFunctions:
tdistpdf,
tdistlogpdf,
# R implementations
# For pdf and logpdf we use the Julia implementation
using .RFunctions:
tdistcdf,
tdistccdf,
tdistlogcdf,
Expand All @@ -12,8 +12,11 @@ import .RFunctions:
tdistinvlogcdf,
tdistinvlogccdf

# pdf for numbers with generic types
tdistpdf::Real, x::Number) = gamma((ν + 1) / 2) / (sqrt* pi) * gamma/ 2)) * (1 + x^2 / ν)^(-+ 1) / 2)
# Julia implementations
tdistpdf::Real, x::Real) = exp(tdistlogpdf(ν, x))

# logpdf for numbers with generic types
tdistlogpdf::Real, x::Number) = loggamma((ν + 1) / 2) - log* pi) / 2 - loggamma/ 2) + (-+ 1) / 2) * log1p(x^2 / ν)
tdistlogpdf::Real, x::Real) = tdistlogpdf(promote(ν, x)...)
function tdistlogpdf::T, x::T) where {T<:Real}
νp12 =+ 1) / 2
return loggamma(νp12) - (logπ + log(ν)) / 2 - loggamma/ 2) - νp12 * log1p(x^2 / ν)
end
Loading

2 comments on commit 53e8888

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/45672

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.9.11 -m "<description of version>" 53e888849e6812dc03bbfe677d8bc9833b1aa978
git push origin v0.9.11

Please sign in to comment.