Skip to content

Commit

Permalink
Merge pull request #71 from TuringLang/csp/serial
Browse files Browse the repository at this point in the history
Add MCMCSerial
  • Loading branch information
cpfiffer authored May 13, 2021
2 parents a089476 + 48ba9b0 commit 686a341
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "3.1.0"
version = "3.2.0"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Expand Down
6 changes: 3 additions & 3 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ AbstractMCMC.sample(
::AbstractRNG,
::AbstractMCMC.AbstractModel,
::AbstractMCMC.AbstractSampler,
::AbstractMCMC.AbstractMCMCParallel,
::AbstractMCMC.AbstractMCMCEnsemble,
::Integer,
::Integer,
)
```

Two algorithms are provided for parallel sampling with multiple threads and multiple processes,
respectively:
Two algorithms are provided for parallel sampling with multiple threads and multiple processes, and one allows for the user to sample multiple chains in serial (no parallelization):
```@docs
AbstractMCMC.MCMCThreads
AbstractMCMC.MCMCDistributed
AbstractMCMC.MCMCSerial
```

## Common keyword arguments
Expand Down
25 changes: 17 additions & 8 deletions src/AbstractMCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ using StatsBase: sample
export sample

# Parallel sampling types
export MCMCThreads, MCMCDistributed
export MCMCThreads, MCMCDistributed, MCMCSerial

"""
AbstractChains
Expand Down Expand Up @@ -48,34 +48,43 @@ An `AbstractModel` represents a generic model type that can be used to perform i
abstract type AbstractModel end

"""
AbstractMCMCParallel
AbstractMCMCEnsemble
An `AbstractMCMCParallel` algorithm represents a specific algorithm for sampling MCMC chains
An `AbstractMCMCEnsemble` algorithm represents a specific algorithm for sampling MCMC chains
in parallel.
"""
abstract type AbstractMCMCParallel end
abstract type AbstractMCMCEnsemble end

"""
MCMCThreads
The `MCMCThreads` algorithm allows to sample MCMC chains in parallel using multiple
The `MCMCThreads` algorithm allows users to sample MCMC chains in parallel using multiple
threads.
"""
struct MCMCThreads <: AbstractMCMCParallel end
struct MCMCThreads <: AbstractMCMCEnsemble end

"""
MCMCDistributed
The `MCMCDistributed` algorithm allows to sample MCMC chains in parallel using multiple
The `MCMCDistributed` algorithm allows users to sample MCMC chains in parallel using multiple
processes.
"""
struct MCMCDistributed <: AbstractMCMCParallel end
struct MCMCDistributed <: AbstractMCMCEnsemble end


"""
MCMCSerial
The `MCMCSerial` algorithm allows users to sample serially, with no thread or process parallelism.
"""
struct MCMCSerial <: AbstractMCMCEnsemble end

include("samplingstats.jl")
include("logging.jl")
include("interface.jl")
include("sample.jl")
include("stepper.jl")
include("transducer.jl")
include("deprecations.jl")

end # module AbstractMCMC
2 changes: 2 additions & 0 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Deprecate the old name AbstractMCMCParallel in favor of AbstractMCMCEnsemble
Base.@deprecate_binding AbstractMCMCParallel AbstractMCMCEnsemble false
30 changes: 28 additions & 2 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ end
function StatsBase.sample(
model::AbstractModel,
sampler::AbstractSampler,
parallel::AbstractMCMCParallel,
parallel::AbstractMCMCEnsemble,
N::Integer,
nchains::Integer;
kwargs...
Expand All @@ -80,7 +80,7 @@ function StatsBase.sample(
rng::Random.AbstractRNG,
model::AbstractModel,
sampler::AbstractSampler,
parallel::AbstractMCMCParallel,
parallel::AbstractMCMCEnsemble,
N::Integer,
nchains::Integer;
kwargs...
Expand Down Expand Up @@ -444,5 +444,31 @@ function mcmcsample(
return chainsstack(tighten_eltype(chains))
end

function mcmcsample(
rng::Random.AbstractRNG,
model::AbstractModel,
sampler::AbstractSampler,
::MCMCSerial,
N::Integer,
nchains::Integer;
progressname = "Sampling",
kwargs...
)
# Check if the number of chains is larger than the number of samples
if nchains > N
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
end

# Sample the chains.
chains = map(
i -> StatsBase.sample(rng, model, sampler, N; progressname = string(progressname, " (Chain ", i, " of ", nchains, ")"),
kwargs...),
1:nchains
)

# Concatenate the chains together.
return chainsstack(tighten_eltype(chains))
end

tighten_eltype(x) = x
tighten_eltype(x::Vector{Any}) = map(identity, x)
46 changes: 46 additions & 0 deletions test/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,52 @@
@test all(l.level > Logging.LogLevel(-1) for l in logs)
end

@testset "Serial sampling" begin
# No dedicated chains type
N = 10_000
chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000)
@test chains isa Vector{<:Vector{<:MySample}}
@test length(chains) == 1000
@test all(length(x) == N for x in chains)

Random.seed!(1234)
chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000;
chain_type = MyChain)

# Test output type and size.
@test chains isa Vector{<:MyChain}
@test all(c.as[1] === missing for c in chains)
@test length(chains) == 1000
@test all(x -> length(x.as) == length(x.bs) == N, chains)

# Test some statistical properties.
@test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains)
@test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains)
@test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains)
@test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains)

# Test reproducibility.
Random.seed!(1234)
chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000;
chain_type = MyChain)

@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)

# Unexpected order of arguments.
str = "Number of chains (10) is greater than number of samples per chain (5)"
@test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(),
MCMCSerial(), 5, 10;
chain_type = MyChain)

# Suppress output.
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
sample(MyModel(), MySampler(), MCMCSerial(), 10_000, 100;
progress = false, chain_type = MyChain)
end
@test all(l.level > Logging.LogLevel(-1) for l in logs)
end

@testset "Chain constructors" begin
chain1 = sample(MyModel(), MySampler(), 100; sleepy = true)
chain2 = sample(MyModel(), MySampler(), 100; sleepy = true, chain_type = MyChain)
Expand Down

2 comments on commit 686a341

@cpfiffer
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/36712

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v3.2.0 -m "<description of version>" 686a341ebba9d4a5d5d70961e7c4ccf59656936b
git push origin v3.2.0

Please sign in to comment.