-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Enhance ETI function * Add eti/eti! to API * Add docstrings to eti * Update summarize docstring * Document ETI in docs * Cross-reference eti from hdi * Add missing export statement
- Loading branch information
Showing
6 changed files
with
163 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,97 @@ | ||
function eti(x::AbstractVecOrMat{<:Real}; prob::Real=DEFAULT_INTERVAL_PROB) | ||
""" | ||
eti(samples::AbstractVecOrMat{<:Real}; [prob, kwargs...]) -> IntervalSets.ClosedInterval | ||
eti(samples::AbstractArray{<:Real}; [prob, kwargs...]) -> Array{<:IntervalSets.ClosedInterval} | ||
Estimate the equal-tailed interval (ETI) of `samples` for the probability `prob`. | ||
The ETI of a given probability is the credible interval wih the property that the | ||
probability of being below the interval is equal to the probability of being above it. | ||
That is, it is defined by the `(1-prob)/2` and `1 - (1-prob)/2` quantiles of the samples. | ||
See also: [`eti!`](@ref), [`hdi`](@ref), [`hdi!`](@ref). | ||
# Arguments | ||
- `samples`: an array of shape `(draws[, chains[, params...]])`. If multiple parameters are | ||
present | ||
# Keywords | ||
- `prob`: the probability mass to be contained in the ETI. Default is | ||
`$(DEFAULT_INTERVAL_PROB)`. | ||
- `kwargs`: remaining keywords are passed to `Statistics.quantile`. | ||
# Returns | ||
- `intervals`: If `samples` is a vector or matrix, then a single | ||
`IntervalSets.ClosedInterval` is returned. Otherwise, an array with the shape | ||
`(params...,)`, is returned, containing a marginal ETI for each parameter. | ||
!!! note | ||
Any default value of `prob` is arbitrary. The default value of | ||
`prob=$(DEFAULT_INTERVAL_PROB)` instead of a more common default like `prob=0.95` is | ||
chosen to reminder the user of this arbitrariness. | ||
# Examples | ||
Here we calculate the 83% ETI for a normal random variable: | ||
```jldoctest eti; setup = :(using Random; Random.seed!(78)) | ||
julia> x = randn(2_000); | ||
julia> eti(x; prob=0.83) | ||
-1.3740585250299766 .. 1.2860771129421198 | ||
``` | ||
We can also calculate the ETI for a 3-dimensional array of samples: | ||
```jldoctest eti; setup = :(using Random; Random.seed!(67)) | ||
julia> x = randn(1_000, 1, 1) .+ reshape(0:5:10, 1, 1, :); | ||
julia> eti(x) | ||
3-element Vector{IntervalSets.ClosedInterval{Float64}}: | ||
-1.951006825019686 .. 1.9011666217153793 | ||
3.048993174980314 .. 6.9011666217153795 | ||
8.048993174980314 .. 11.90116662171538 | ||
``` | ||
""" | ||
function eti( | ||
x::AbstractArray{<:Real}; | ||
prob::Real=DEFAULT_INTERVAL_PROB, | ||
sorted::Bool=false, | ||
kwargs..., | ||
) | ||
return eti!(sorted ? x : _copymutable(x); prob, sorted, kwargs...) | ||
end | ||
|
||
""" | ||
eti!(samples::AbstractArray{<:Real}; [prob, kwargs...]) | ||
A version of [`eti`](@ref) that partially sorts `samples` in-place while computing the ETI. | ||
See also: [`eti`](@ref), [`hdi`](@ref), [`hdi!`](@ref). | ||
""" | ||
function eti!(x::AbstractArray{<:Real}; prob::Real=DEFAULT_INTERVAL_PROB, kwargs...) | ||
ndims(x) > 0 || | ||
throw(ArgumentError("ETI cannot be computed for a 0-dimensional array.")) | ||
0 < prob < 1 || throw(DomainError(prob, "ETI `prob` must be in the range `(0, 1)`.")) | ||
isempty(x) && throw(ArgumentError("ETI cannot be computed for an empty array.")) | ||
return _eti!(x, prob; kwargs...) | ||
end | ||
|
||
function _eti!(x::AbstractVecOrMat{<:Real}, prob::Real; kwargs...) | ||
if any(isnan, x) | ||
T = float(promote_type(eltype(x), typeof(prob))) | ||
lower = upper = T(NaN) | ||
else | ||
alpha = (1 - prob) / 2 | ||
lower, upper = Statistics.quantile(vec(x), (alpha, 1 - alpha)) | ||
lower, upper = Statistics.quantile!(vec(x), (alpha, 1 - alpha); kwargs...) | ||
end | ||
return IntervalSets.ClosedInterval(lower, upper) | ||
end | ||
function _eti!(x::AbstractArray, prob::Real; kwargs...) | ||
axes_out = _param_axes(x) | ||
T = float(promote_type(eltype(x), typeof(prob))) | ||
interval = similar(x, IntervalSets.ClosedInterval{T}, axes_out) | ||
for (i, x_slice) in zip(eachindex(interval), _eachparam(x)) | ||
interval[i] = _eti!(x_slice, prob; kwargs...) | ||
end | ||
return interval | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,43 +1,87 @@ | ||
using IntervalSets | ||
using OffsetArrays | ||
using PosteriorStats | ||
using Statistics | ||
using Test | ||
|
||
@testset "PosteriorStats.eti" begin | ||
@testset "eti/eti!" begin | ||
@testset "AbstractVecOrMat" begin | ||
@testset for sz in (100, 1_000, (1_000, 2)), | ||
prob in (0.7, 0.76, 0.8, 0.88), | ||
T in (Float32, Float64) | ||
T in (Float32, Float64, Int64) | ||
|
||
S = Base.promote_eltype(one(T), prob) | ||
n = prod(sz) | ||
x = T <: Integer ? rand(T(1):T(30), sz) : randn(T, sz) | ||
r = @inferred PosteriorStats.eti(x; prob) | ||
S = Base.promote_eltype(one(T), prob) | ||
x = T <: Integer ? rand(T(1):T(30), n) : randn(T, n) | ||
r = @inferred eti(x; prob) | ||
@test r isa ClosedInterval{S} | ||
l, u = IntervalSets.endpoints(r) | ||
frac_in_interval = mean(∈(r), x) | ||
@test frac_in_interval ≈ prob | ||
@test count(<(l), x) == count(>(u), x) | ||
if !(T <: Integer) | ||
l, u = IntervalSets.endpoints(r) | ||
frac_in_interval = mean(∈(r), x) | ||
@test frac_in_interval ≈ prob | ||
@test count(<(l), x) == count(>(u), x) | ||
end | ||
|
||
@test eti!(copy(x); prob) == r | ||
end | ||
end | ||
|
||
@testset "edge cases and errors" begin | ||
@testset "NaNs returned if contains NaNs" begin | ||
x = randn(1000) | ||
x[3] = NaN | ||
@test isequal(PosteriorStats.eti(x), NaN .. NaN) | ||
@test isequal(eti(x), NaN .. NaN) | ||
end | ||
|
||
@testset "errors for empty array" begin | ||
x = Float64[] | ||
@test_throws ArgumentError PosteriorStats.eti(x) | ||
@test_throws ArgumentError eti(x) | ||
end | ||
|
||
@testset "errors for 0-dimensional array" begin | ||
x = fill(1.0) | ||
@test_throws ArgumentError eti(x) | ||
end | ||
|
||
@testset "test errors when prob is not in (0, 1)" begin | ||
x = randn(1_000) | ||
@testset for prob in (0, 1, -0.1, 1.1, NaN) | ||
@test_throws DomainError PosteriorStats.eti(x; prob) | ||
@test_throws DomainError eti(x; prob) | ||
end | ||
end | ||
end | ||
|
||
@testset "AbstractArray consistent with AbstractVector" begin | ||
@testset for sz in ((100, 2), (100, 2, 3), (100, 2, 3, 4)), | ||
prob in (0.72, 0.81), | ||
T in (Float32, Float64, Int64) | ||
|
||
x = T <: Integer ? rand(T(1):T(30), sz) : randn(T, sz) | ||
r = @inferred eti(x; prob) | ||
if ndims(x) == 2 | ||
@test r isa ClosedInterval | ||
@test r == eti(vec(x); prob) | ||
else | ||
@test r isa Array{<:ClosedInterval,ndims(x) - 2} | ||
r_slices = dropdims( | ||
mapslices(x -> eti(x; prob), x; dims=(1, 2)); dims=(1, 2) | ||
) | ||
@test r == r_slices | ||
end | ||
|
||
@test eti!(copy(x); prob) == r | ||
end | ||
end | ||
|
||
@testset "OffsetArray" begin | ||
@testset for n in (100, 1_000), prob in (0.732, 0.864), T in (Float32, Float64) | ||
x = randn(T, (n, 2, 3, 4)) | ||
xoff = OffsetArray(x, (-1, 2, -3, 4)) | ||
r = eti(x; prob) | ||
roff = @inferred eti(xoff; prob) | ||
@test roff isa OffsetMatrix{<:ClosedInterval} | ||
@test axes(roff) == (axes(xoff, 3), axes(xoff, 4)) | ||
@test collect(roff) == r | ||
end | ||
end | ||
end |