Skip to content

Commit

Permalink
Merge pull request #86 from SciML/distributionsext
Browse files Browse the repository at this point in the history
Make Distributions.jl into a weakdep
  • Loading branch information
ChrisRackauckas authored Sep 9, 2023
2 parents 537e1fd + d57cd68 commit 9815576
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 35 deletions.
9 changes: 8 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,20 @@ version = "0.3.2"
[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LatticeRules = "73f95e8e-ec14-4e6a-8b18-0d2e271c4e55"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Sobol = "ed01d8cd-4d21-5b2a-85b4-cc3bdc58bad4"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"

[extensions]
QuasiMonteCarloDistributionsExt = "Distributions"

[compat]
Accessors = "0.1"
ConcreteStructs = "0.2"
Expand All @@ -25,6 +31,7 @@ StatsBase = "0.33, 0.34"
julia = "1.6"

[extras]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
Expand Down
72 changes: 72 additions & 0 deletions ext/QuasiMonteCarloDistributionsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
module QuasiMonteCarloDistributionsExt

using QuasiMonteCarlo
isdefined(Base, :get_extension) ? (import Distributions) : (import ..Distributions)

"""
```julia
sample(n::Integer, lb::T, ub::T, D::Distributions.Sampleable, T = eltype(D))
sample(n::Integer,
lb::T,
ub::T,
D::Distributions.Sampleable) where {T <: Union{Base.AbstractVecOrTuple, Number}}
```
Return a point set from a distribution `D`:
- `n` is the number of points to sample.
- `D` is a `Distributions.Sampleable` from Distributions.jl.
The point set is in a `d`-dimensional unit box `[0, 1]^d`.
If the bounds are specified instead of just `d`, the sample is transformed (translation + scaling) into a box `[lb, ub]` where:
- `lb` is the lower bound for each variable. Its length fixes the dimensionality of the sample.
- `ub` is the upper bound. Its dimension must match `length(lb)`.
"""
function QuasiMonteCarlo.sample(n::Integer, d::Integer, D::Distributions.Sampleable, T = eltype(D))
@assert n>0 QuasiMonteCarlo.ZERO_SAMPLES_MESSAGE
x = [[rand(D) for j in 1:d] for i in 1:n]
return reduce(hcat, x)
end

"""
```julia
sample(n::Integer, d::Integer, S::Distributions.Sampleable, T = Float64)
sample(n::Integer,
lb::T,
ub::T,
S::Distributions.Sampleable) where {T <: Union{Base.AbstractVecOrTuple, Number}}
```
Return a QMC point set where:
- `n` is the number of points to sample.
- `S` is the quasi-Monte Carlo sampling strategy.
The point set is in a `d`-dimensional unit box `[0, 1]^d`.
If the bounds are specified, the sample is transformed (translation + scaling) into a box `[lb, ub]` where:
- `lb` is the lower bound for each variable. Its length fixes the dimensionality of the sample.
- `ub` is the upper bound. Its dimension must match `length(lb)`.
In the first method the type of the point set is specified by `T` while in the second method the output type is infered from the bound types.
"""
function QuasiMonteCarlo.sample(n::Integer, lb::T, ub::T,
S::D) where {T <: Union{Base.AbstractVecOrTuple, Number},
D <: Distributions.Sampleable}
QuasiMonteCarlo._check_sequence(lb, ub, n)
lb = float.(lb)
ub = float.(ub)
out = QuasiMonteCarlo.sample(n, length(lb), S, eltype(lb))
return (ub .- lb) .* out .+ lb
end

function QuasiMonteCarlo.DesignMatrix(N, d, D::Distributions.Sampleable, num_mats, T = Float64)
X = QuasiMonteCarlo.initialize(N, d, D, T)
return QuasiMonteCarlo.DistributionDesignMat(X, D, num_mats)
end

function QuasiMonteCarlo.initialize(n, d, D::Distributions.Sampleable, T = Float64)
# Generate unrandomized sequence
X = zeros(T, d, n)
return X
end


end
35 changes: 9 additions & 26 deletions src/QuasiMonteCarlo.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module QuasiMonteCarlo

using Sobol, LatticeRules, Distributions, Primes, LinearAlgebra, Random
using Sobol, LatticeRules, Primes, LinearAlgebra, Random
using ConcreteStructs, Accessors

abstract type SamplingAlgorithm end
Expand Down Expand Up @@ -56,38 +56,14 @@ In the first method the type of the point set is specified by `T` while in the s
"""
function sample(n::Integer, lb::T, ub::T,
S::D) where {T <: Union{Base.AbstractVecOrTuple, Number},
D <: Union{SamplingAlgorithm, Distributions.Sampleable}}
D <: SamplingAlgorithm}
_check_sequence(lb, ub, n)
lb = float.(lb)
ub = float.(ub)
out = sample(n, length(lb), S, eltype(lb))
return (ub .- lb) .* out .+ lb
end

"""
```julia
sample(n::Integer, lb::T, ub::T, D::Distributions.Sampleable, T = eltype(D))
sample(n::Integer,
lb::T,
ub::T,
D::Distributions.Sampleable) where {T <: Union{Base.AbstractVecOrTuple, Number}}
```
Return a point set from a distribution `D`:
- `n` is the number of points to sample.
- `D` is a `Distributions.Sampleable` from Distributions.jl.
The point set is in a `d`-dimensional unit box `[0, 1]^d`.
If the bounds are specified instead of just `d`, the sample is transformed (translation + scaling) into a box `[lb, ub]` where:
- `lb` is the lower bound for each variable. Its length fixes the dimensionality of the sample.
- `ub` is the upper bound. Its dimension must match `length(lb)`.
"""
function sample(n::Integer, d::Integer, D::Distributions.Sampleable, T = eltype(D))
@assert n>0 ZERO_SAMPLES_MESSAGE
x = [[rand(D) for j in 1:d] for i in 1:n]
return reduce(hcat, x)
end

# See https://discourse.julialang.org/t/is-there-a-dedicated-function-computing-m-int-log-b-b-m/89776/10
function logi(b::Integer, n::Integer)
m = round(Int, log(b, n))
Expand Down Expand Up @@ -119,6 +95,13 @@ include("RandomizedQuasiMonteCarlo/shifting.jl")
include("RandomizedQuasiMonteCarlo/scrambling_base_b.jl")
include("RandomizedQuasiMonteCarlo/iterators.jl")

import Requires
@static if !isdefined(Base, :get_extension)
function __init__()
Requires.@require Distributions="31c24e10-a181-5473-b8eb-7969acd0382f" begin include("../ext/QuasiMonteCarloDistributionsExt.jl") end
end
end

export SamplingAlgorithm,
GridSample,
SobolSample,
Expand Down
11 changes: 3 additions & 8 deletions src/RandomizedQuasiMonteCarlo/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ Create an iterator for mutiple distribution randomization. The distribution is c
This is equivalent to using `rand!(D, X)` for some matrix `X`.
One can use the commun [`DesignMatrix`](@ref) interface to create the iterator.
"""
mutable struct DistributionDesignMat{T<:Real} <: AbstractDesignMatrix
mutable struct DistributionDesignMat{T<:Real, T2} <: AbstractDesignMatrix
X::Array{T,2} #array of size (N, d)
D::Distributions.Sampleable
D::T2#::Distributions.Sampleable
count::Int
end

Expand Down Expand Up @@ -101,11 +101,6 @@ function DesignMatrix(N, d, S::DeterministicSamplingAlgorithm, R::Shift, num_mat
return ShiftDesignMat(X, R, num_mats)
end

function DesignMatrix(N, d, D::Distributions.Sampleable, num_mats, T = Float64)
X = initialize(N, d, D, T)
return DistributionDesignMat(X, D, num_mats)
end

function DesignMatrix(N, d, D::RandomSample, num_mats, T = Float64)
X = initialize(N, d, D, T)
return RandomDesignMat(X, num_mats)
Expand Down Expand Up @@ -170,7 +165,7 @@ function initialize(n, d, sampler, R::Shift, T = Float64)
end

## Distribution
function initialize(n, d, D::Union{Distributions.Sampleable,RandomSample}, T = Float64)
function initialize(n, d, D::RandomSample, T = Float64)
# Generate unrandomized sequence
X = zeros(T, d, n)
return X
Expand Down

0 comments on commit 9815576

Please sign in to comment.