Skip to content

Commit

Permalink
Merge pull request #60 from TuringLang/dw/threshold
Browse files Browse the repository at this point in the history
  • Loading branch information
cpfiffer authored Mar 19, 2021
2 parents 0ebfcf5 + 77807d0 commit 1474452
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "2.3.0"
version = "2.4.0"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Expand Down
63 changes: 47 additions & 16 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,23 @@ function mcmcsample(
Ntotal = thinning * (N - 1) + discard_initial + 1

@ifwithprogresslogger progress name=progressname begin
# Determine threshold values for progress logging
# (one update per 0.5% of progress)
if progress
threshold = Ntotal ÷ 200
next_update = threshold
end

# Obtain the initial sample and state.
sample, state = step(rng, model, sampler; kwargs...)

# Discard initial samples.
for i in 1:(discard_initial - 1)
# Update the progress bar.
progress && ProgressLogging.@logprogress i/Ntotal
if progress && i >= next_update
ProgressLogging.@logprogress i/Ntotal
next_update = i + threshold
end

# Obtain the next sample and state.
sample, state = step(rng, model, sampler, state; kwargs...)
Expand All @@ -106,20 +116,23 @@ function mcmcsample(
samples = save!!(samples, sample, 1, model, sampler, N; kwargs...)

# Update the progress bar.
progress && ProgressLogging.@logprogress (1 + discard_initial) / Ntotal
itotal = 1 + discard_initial
if progress && itotal >= next_update
ProgressLogging.@logprogress itotal / Ntotal
next_update = itotal + threshold
end

# Step through the sampler.
itotal = 1 + discard_initial
for i in 2:N
# Discard thinned samples.
for _ in 1:(thinning - 1)
# Obtain the next sample and state.
sample, state = step(rng, model, sampler, state; kwargs...)

# Update progress bar.
if progress
itotal += 1
if progress && (itotal += 1) >= next_update
ProgressLogging.@logprogress itotal / Ntotal
next_update = itotal + threshold
end
end

Expand All @@ -133,9 +146,9 @@ function mcmcsample(
samples = save!!(samples, sample, i, model, sampler, N; kwargs...)

# Update the progress bar.
if progress
itotal += 1
if progress && (itotal += 1) >= next_update
ProgressLogging.@logprogress itotal / Ntotal
next_update = itotal + threshold
end
end
end
Expand Down Expand Up @@ -263,17 +276,25 @@ function mcmcsample(
@ifwithprogresslogger progress name=progressname begin
# Create a channel for progress logging.
if progress
channel = Distributed.RemoteChannel(() -> Channel{Bool}(nchains))
channel = Channel{Bool}(length(interval))
end

Distributed.@sync begin
if progress
# Update the progress bar.
Distributed.@async begin
# Update the progress bar.
# Determine threshold values for progress logging
# (one update per 0.5% of progress)
threshold = nchains ÷ 200
nextprogresschains = threshold

progresschains = 0
while take!(channel)
progresschains += 1
ProgressLogging.@logprogress progresschains/nchains
if progresschains >= nextprogresschains
ProgressLogging.@logprogress progresschains/nchains
nextprogresschains = progresschains + threshold
end
end
end
end
Expand Down Expand Up @@ -334,19 +355,29 @@ function mcmcsample(
# Set up worker pool.
pool = Distributed.CachingPool(Distributed.workers())

# Create a channel for progress logging.
channel = progress ? Distributed.RemoteChannel(() -> Channel{Bool}(nchains)) : nothing

local chains
@ifwithprogresslogger progress name=progressname begin
# Create a channel for progress logging.
if progress
channel = Distributed.RemoteChannel(() -> Channel{Bool}(Distributed.nworkers()))
end

Distributed.@sync begin
# Update the progress bar.
if progress
# Update the progress bar.
Distributed.@async begin
# Determine threshold values for progress logging
# (one update per 0.5% of progress)
threshold = nchains ÷ 200
nextprogresschains = threshold

progresschains = 0
while take!(channel)
progresschains += 1
ProgressLogging.@logprogress progresschains/nchains
if progresschains >= nextprogresschains
ProgressLogging.@logprogress progresschains/nchains
nextprogresschains = progresschains + threshold
end
end
end
end
Expand All @@ -362,7 +393,7 @@ function mcmcsample(
progress = false, kwargs...)

# Update the progress bar.
channel === nothing || put!(channel, true)
progress && put!(channel, true)

# Return the new chain.
return chain
Expand Down

2 comments on commit 1474452

@cpfiffer
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/32372

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v2.4.0 -m "<description of version>" 1474452d60452ddd3cc2748b1f3d3099147c7a9a
git push origin v2.4.0

Please sign in to comment.