Skip to content

Commit

Permalink
Promote arguments of norminvcdf and norminvccdf (#132)
Browse files Browse the repository at this point in the history
* Promote arguments of `norminvcdf` and `norminvccdf`

* Simplify promotions

* Improve clarity of tests

Co-authored-by: Seth Axen <seth.axen@gmail.com>

* Simplify tests

* Add method with `rtol`

* Simplification

Co-authored-by: Seth Axen <seth.axen@gmail.com>
  • Loading branch information
devmotion and sethaxen authored Jan 24, 2022
1 parent df291cc commit 13e231a
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "StatsFuns"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "0.9.14"
version = "0.9.15"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
26 changes: 15 additions & 11 deletions src/distrs/norm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ function normlogpdf(μ::Real, σ::Real, x::Number)
z = zval(μ, σ, x)
end
normlogpdf(z) - log(σ)
end
end

# cdf
normcdf(z::Number) = erfc(-z * invsqrt2)/2
function normcdf::Real, σ::Real, x::Number)
if iszero(σ) && x == μ
z = zval(zero(μ), σ, one(x))
else
z = zval(μ, σ, x)
else
z = zval(μ, σ, x)
end
normcdf(z)
end
Expand All @@ -56,8 +56,8 @@ normccdf(z::Number) = erfc(z * invsqrt2)/2
function normccdf::Real, σ::Real, x::Number)
if iszero(σ) && x == μ
z = zval(zero(μ), σ, one(x))
else
z = zval(μ, σ, x)
else
z = zval(μ, σ, x)
end
normccdf(z)
end
Expand All @@ -69,8 +69,8 @@ normlogcdf(z::Number) = z < -1.0 ?
function normlogcdf::Real, σ::Real, x::Number)
if iszero(σ) && x == μ
z = zval(zero(μ), σ, one(x))
else
z = zval(μ, σ, x)
else
z = zval(μ, σ, x)
end
normlogcdf(z)
end
Expand All @@ -82,17 +82,21 @@ normlogccdf(z::Number) = z > 1.0 ?
function normlogccdf::Real, σ::Real, x::Number)
if iszero(σ) && x == μ
z = zval(zero(μ), σ, one(x))
else
z = zval(μ, σ, x)
else
z = zval(μ, σ, x)
end
normlogccdf(z)
end

norminvcdf(p::Real) = -erfcinv(2*p) * sqrt2
norminvcdf::Real, σ::Real, p::Real) = xval(μ, σ, norminvcdf(p))
# Promote to ensure that we don't compute erfcinv in low precision and then promote
norminvcdf::Real, σ::Real, p::Real) = norminvcdf(promote(μ, σ, p)...)
norminvcdf::T, σ::T, p::T) where {T<:Real} = xval(μ, σ, norminvcdf(p))

norminvccdf(p::Real) = erfcinv(2*p) * sqrt2
norminvccdf::Real, σ::Real, p::Real) = xval(μ, σ, norminvccdf(p))
# Promote to ensure that we don't compute erfcinv in low precision and then promote
norminvccdf::Real, σ::Real, p::Real) = norminvccdf(promote(μ, σ, p)...)
norminvccdf::T, σ::T, p::T) where {T<:Real} = xval(μ, σ, norminvccdf(p))

# invlogcdf. Fixme! Support more precisions than Float64
norminvlogcdf(lp::Union{Float16,Float32}) = convert(typeof(lp), _norminvlogcdf_impl(Float64(lp)))
Expand Down
12 changes: 11 additions & 1 deletion test/rmath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,15 @@ end
get_statsfun(fname) = eval(Symbol(fname))
get_rmathfun(fname) = eval(Meta.parse(string("RFunctions.", fname)))

function rmathcomp(basename, params, X::AbstractArray, rtol=100eps(float(one(eltype(X)))))
function rmathcomp(basename, params, X::AbstractArray)
# compute default tolerance:
# has to take into account `params` as well since otherwise e.g. `X::Array{<:Rational}`
# always uses a tolerance based on `eps(one(Float64))` even when parameters are of type
# Float32
rtol = 100 * eps(float(one(promote_type(Base.promote_typeof(params...), eltype(X)))))
rmathcomp(basename, params, X, rtol)
end
function rmathcomp(basename, params, X::AbstractArray, rtol)
# tackle pdf specially
has_pdf = true
if basename == "srdist"
Expand Down Expand Up @@ -264,6 +272,8 @@ end
((0.0, 0.0), -6.0:0.1:6.0),
((0f0, 1f0), -6f0:0.01f0:6f0),
((0.0, 1.0), -6f0:0.01f0:6f0),
((0, 2), -6//1:1//2:6//1),
((0f0, 2f0), -6//1:1//2:6//1),
# Fail since `SpecialFunctions.erfcx` is not implemented for `Float16`
#((Float16(0), Float16(1)), -Float16(6):Float16(0.01):Float16(6)),
#((0f0, 1f0), -Float16(6):Float16(0.01):Float16(6)),
Expand Down

2 comments on commit 13e231a

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/53112

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.9.15 -m "<description of version>" 13e231a0a22e716426b73cb87ff3b8b24e33aaf1
git push origin v0.9.15

Please sign in to comment.