Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Real arguments in R functions and use Julia implementations only #125

Merged
merged 6 commits into from
Sep 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "StatsFuns"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "0.9.10"
version = "0.9.11"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -13,7 +13,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
[compat]
ChainRulesCore = "1"
IrrationalConstants = "0.1"
LogExpFunctions = "0.3"
LogExpFunctions = "0.3.2"
Reexport = "1"
Rmath = "0.4, 0.5, 0.6, 0.7"
SpecialFunctions = "0.8, 0.9, 0.10, 1.0"
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ log4π, # log(4π)
# basicfuns
xlogx, # x * log(x), or 0 when x is zero
xlogy, # x * log(y), or 0 when x is zero
xlog1py, # x * log(1 + y) for x > 0, or 0 when x == 0
logistic, # 1 / (1 + exp(-x))
logit, # log(x / (1 - x))
log1psq, # log(1 + x^2)
Expand Down
1 change: 1 addition & 0 deletions src/StatsFuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import ChainRulesCore
@reexport using LogExpFunctions:
xlogx, # x * log(x) for x > 0, or 0 when x == 0
xlogy, # x * log(y) for x > 0, or 0 when x == 0
xlog1py, # x * log(1 + y) for x > 0, or 0 when x == 0
logistic, # 1 / (1 + exp(-x))
logit, # log(x / (1 - x))
log1psq, # log(1 + x^2)
Expand Down
6 changes: 3 additions & 3 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ ChainRulesCore.@scalar_rule(
binomlogpdf(n::Real, p::Real, k::Real),
@setup(z = digamma(n - k + 1)),
(
digamma(n + 2) - z + log1p(-p) - 1 / (1 + n),
ChainRulesCore.NoTangent(),
nalimilan marked this conversation as resolved.
Show resolved Hide resolved
(k / p - n) / (1 - p),
z - digamma(k + 1) + logit(p),
ChainRulesCore.NoTangent(),
),
)

Expand Down Expand Up @@ -59,7 +59,7 @@ ChainRulesCore.@scalar_rule(

ChainRulesCore.@scalar_rule(
poislogpdf(λ::Number, x::Number),
((iszero(x) && iszero(λ) ? zero(x / λ) : x / λ) - 1, log(λ) - digamma(x + 1)),
((iszero(x) && iszero(λ) ? zero(x / λ) : x / λ) - 1, ChainRulesCore.NoTangent()),
)

ChainRulesCore.@scalar_rule(
Expand Down
19 changes: 12 additions & 7 deletions src/distrs/beta.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# functions related to beta distributions

import .RFunctions:
betapdf,
betalogpdf,
# R implementations
# For pdf and logpdf we use the Julia implementation
using .RFunctions:
betacdf,
betaccdf,
betalogcdf,
Expand All @@ -12,8 +12,13 @@ import .RFunctions:
betainvlogcdf,
betainvlogccdf

# pdf for numbers with generic types
betapdf(α::Real, β::Real, x::Number) = x^(α - 1) * (1 - x)^(β - 1) / beta(α, β)
# Julia implementations
betapdf(α::Real, β::Real, x::Real) = exp(betalogpdf(α, β, x))

# logpdf for numbers with generic types
betalogpdf(α::Real, β::Real, x::Number) = (α - 1) * log(x) + (β - 1) * log1p(-x) - logbeta(α, β)
betalogpdf(α::Real, β::Real, x::Real) = betalogpdf(promote(α, β, x)...)
function betalogpdf(α::T, β::T, x::T) where {T<:Real}
# we ensure that `log(x)` and `log1p(-x)` do not error
y = clamp(x, 0, 1)
val = xlogy(α - 1, y) + xlog1py(β - 1, -y) - logbeta(α, β)
return x < 0 || x > 1 ? oftype(val, -Inf) : val
end
17 changes: 11 additions & 6 deletions src/distrs/binom.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# functions related to binomial distribution

import .RFunctions:
binompdf,
binomlogpdf,
# R implementations
# For pdf and logpdf we use the Julia implementation
using .RFunctions:
binomcdf,
binomccdf,
binomlogcdf,
Expand All @@ -12,8 +12,13 @@ import .RFunctions:
binominvlogcdf,
binominvlogccdf

# pdf for numbers with generic types

# Julia implementations
binompdf(n::Real, p::Real, k::Real) = exp(binomlogpdf(n, p, k))

# logpdf for numbers with generic types
binomlogpdf(n::Real, p::Real, k::Real) = -log1p(n) - logbeta(n - k + 1, k + 1) + k * log(p) + (n - k) * log1p(-p)
binomlogpdf(n::Real, p::Real, k::Real) = binomlogpdf(promote(n, p, k)...)
function binomlogpdf(n::T, p::T, k::T) where {T<:Real}
m = clamp(k, 0, n)
val = betalogpdf(m + 1, n - m + 1, p) - log(n + 1)
return 0 <= k <= n && isinteger(k) ? val : oftype(val, -Inf)
end
23 changes: 10 additions & 13 deletions src/distrs/chisq.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# functions related to chi-square distribution

import .RFunctions:
chisqpdf,
chisqlogpdf,
# R implementations
# For pdf and logpdf we use the Julia implementation
using .RFunctions:
chisqcdf,
chisqccdf,
chisqlogcdf,
Expand All @@ -12,14 +12,11 @@ import .RFunctions:
chisqinvlogcdf,
chisqinvlogccdf

# pdf for numbers with generic types
function chisqpdf(k::Real, x::Number)
hk = k / 2 # half k
1 / (2^(hk) * gamma(hk)) * x^(hk - 1) * exp(-x / 2)
end
# Julia implementations
# promotion ensures that we do forward e.g. `chisqpdf(::Int, ::Float32)` to
# `gammapdf(::Float32, ::Int, ::Float32)` but not `gammapdf(::Float64, ::Int, ::Float32)`
chisqpdf(k::Real, x::Real) = chisqpdf(promote(k, x)...)
chisqpdf(k::T, x::T) where {T<:Real} = gammapdf(k / 2, 2, x)

# logpdf for numbers with generic types
function chisqlogpdf(k::Real, x::Number)
hk = k / 2 # half k
-hk * logtwo - loggamma(hk) + (hk - 1) * log(x) - x / 2
end
chisqlogpdf(k::Real, x::Real) = chisqlogpdf(promote(k, x)...)
chisqlogpdf(k::T, x::T) where {T<:Real} = gammalogpdf(k / 2, 2, x)
20 changes: 13 additions & 7 deletions src/distrs/fdist.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# functions related to F distribution

import .RFunctions:
fdistpdf,
fdistlogpdf,
# R implementations
# For pdf and logpdf we use the Julia implementation
using .RFunctions:
fdistcdf,
fdistccdf,
fdistlogcdf,
Expand All @@ -12,8 +12,14 @@ import .RFunctions:
fdistinvlogcdf,
fdistinvlogccdf

# pdf for numbers with generic types
fdistpdf(ν1::Real, ν2::Real, x::Number) = sqrt((ν1 * x)^ν1 * ν2^ν2 / (ν1 * x + ν2)^(ν1 + ν2)) / (x * beta(ν1 / 2, ν2 / 2))
# Julia implementations
fdistpdf(ν1::Real, ν2::Real, x::Real) = exp(fdistlogpdf(ν1, ν2, x))

# logpdf for numbers with generic types
fdistlogpdf(ν1::Real, ν2::Real, x::Number) = (ν1 * log(ν1 * x) + ν2 * log(ν2) - (ν1 + ν2) * log(ν1 * x + ν2)) / 2 - log(x) - logbeta(ν1 / 2, ν2 / 2)
fdistlogpdf(ν1::Real, ν2::Real, x::Real) = fdistlogpdf(promote(ν1, ν2, x)...)
function fdistlogpdf(ν1::T, ν2::T, x::T) where {T<:Real}
# we ensure that `log(x)` does not error if `x < 0`
ν1ν2 = ν1 / ν2
y = max(x, 0)
val = (xlogy(ν1, ν1ν2) + xlogy(ν1 - 2, y) - xlogy(ν1 + ν2, 1 + ν1ν2 * y)) / 2 - logbeta(ν1 / 2, ν2 / 2)
return x < 0 ? oftype(val, -Inf) : val
end
19 changes: 12 additions & 7 deletions src/distrs/gamma.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# functions related to gamma distribution

import .RFunctions:
gammapdf,
gammalogpdf,
# R implementations
# For pdf and logpdf we use the Julia implementation
using .RFunctions:
gammacdf,
gammaccdf,
gammalogcdf,
Expand All @@ -12,8 +12,13 @@ import .RFunctions:
gammainvlogcdf,
gammainvlogccdf

# pdf for numbers with generic types
gammapdf(k::Real, θ::Real, x::Number) = 1 / (gamma(k) * θ^k) * x^(k - 1) * exp(-x / θ)
# Julia implementations
gammapdf(k::Real, θ::Real, x::Real) = exp(gammalogpdf(k, θ, x))

# logpdf for numbers with generic types
gammalogpdf(k::Real, θ::Real, x::Number) = -loggamma(k) - k * log(θ) + (k - 1) * log(x) - x / θ
gammalogpdf(k::Real, θ::Real, x::Real) = gammalogpdf(promote(k, θ, x)...)
function gammalogpdf(k::T, θ::T, x::T) where {T<:Real}
# we ensure that `log(x)` does not error if `x < 0`
xθ = max(x, 0) / θ
val = -loggamma(k) + xlogy(k - 1, xθ) - log(θ) - xθ
return x < 0 ? oftype(val, -Inf) : val
end
3 changes: 2 additions & 1 deletion src/distrs/hyper.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# functions related to hyper-geometric distribution

import .RFunctions:
# R implementations
using .RFunctions:
hyperpdf,
hyperlogpdf,
hypercdf,
Expand Down
3 changes: 2 additions & 1 deletion src/distrs/nbeta.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# functions related to noncentral beta distribution

import .RFunctions:
# R implementations
using .RFunctions:
nbetapdf,
nbetalogpdf,
nbetacdf,
Expand Down
3 changes: 2 additions & 1 deletion src/distrs/nbinom.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# functions related to negative binomial distribution

import .RFunctions:
# R implementations
using .RFunctions:
nbinompdf,
nbinomlogpdf,
nbinomcdf,
Expand Down
3 changes: 2 additions & 1 deletion src/distrs/nchisq.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# functions related to noncentral chi-square distribution

import .RFunctions:
# R implementations
using .RFunctions:
nchisqpdf,
nchisqlogpdf,
nchisqcdf,
Expand Down
3 changes: 2 additions & 1 deletion src/distrs/nfdist.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# functions related to noncentral F distribution

import .RFunctions:
# R implementations
using .RFunctions:
nfdistpdf,
nfdistlogpdf,
nfdistcdf,
Expand Down
3 changes: 2 additions & 1 deletion src/distrs/ntdist.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# functions related to noncentral T distribution

import .RFunctions:
# R implementations
using .RFunctions:
ntdistpdf,
ntdistlogpdf,
ntdistcdf,
Expand Down
26 changes: 9 additions & 17 deletions src/distrs/pois.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# functions related to Poisson distribution

import .RFunctions:
poispdf,
poislogpdf,
# R implementations
# For pdf and logpdf we use the Julia implementation
using .RFunctions:
poiscdf,
poisccdf,
poislogcdf,
Expand All @@ -12,19 +12,11 @@ import .RFunctions:
poisinvlogcdf,
poisinvlogccdf

# generic versions
# Julia implementations
poispdf(λ::Real, x::Real) = exp(poislogpdf(λ, x))

poislogpdf(λ::T, x::T) where {T <: Real} = xlogy(x, λ) - λ - loggamma(x + 1)

poislogpdf(λ::Number, x::Number) = poislogpdf(promote(float(λ), x)...)
nalimilan marked this conversation as resolved.
Show resolved Hide resolved

#=
function poislogpdf(λ::Union{Float32,Float64}, x::Union{Float64,Float32,Integer})
if iszero(λ)
iszero(x) ? zero(λ) : oftype(λ, -Inf)
elseif iszero(x)
else
-lstirling_asym(x + 1)
=#
poislogpdf(λ::Real, x::Real) = poislogpdf(promote(λ, x)...)
function poislogpdf(λ::T, x::T) where {T <: Real}
val = xlogy(x, λ) - λ - loggamma(x + 1)
return x >= 0 && isinteger(x) ? val : oftype(val, -Inf)
end
3 changes: 2 additions & 1 deletion src/distrs/srdist.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# functions related to studentized range distribution

import .RFunctions:
# R implementations
using .RFunctions:
srdistcdf,
srdistccdf,
srdistlogcdf,
Expand Down
17 changes: 10 additions & 7 deletions src/distrs/tdist.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# functions related to student's T distribution

import .RFunctions:
tdistpdf,
tdistlogpdf,
# R implementations
# For pdf and logpdf we use the Julia implementation
using .RFunctions:
tdistcdf,
tdistccdf,
tdistlogcdf,
Expand All @@ -12,8 +12,11 @@ import .RFunctions:
tdistinvlogcdf,
tdistinvlogccdf

# pdf for numbers with generic types
tdistpdf(ν::Real, x::Number) = gamma((ν + 1) / 2) / (sqrt(ν * pi) * gamma(ν / 2)) * (1 + x^2 / ν)^(-(ν + 1) / 2)
# Julia implementations
tdistpdf(ν::Real, x::Real) = exp(tdistlogpdf(ν, x))

# logpdf for numbers with generic types
tdistlogpdf(ν::Real, x::Number) = loggamma((ν + 1) / 2) - log(ν * pi) / 2 - loggamma(ν / 2) + (-(ν + 1) / 2) * log1p(x^2 / ν)
tdistlogpdf(ν::Real, x::Real) = tdistlogpdf(promote(ν, x)...)
function tdistlogpdf(ν::T, x::T) where {T<:Real}
νp12 = (ν + 1) / 2
return loggamma(νp12) - (logπ + log(ν)) / 2 - loggamma(ν / 2) - νp12 * log1p(x^2 / ν)
end
Loading