Skip to content

Commit

Permalink
MH Constructor (#2037)
Browse files Browse the repository at this point in the history
* first draft

* abstractcontext + tests

* bug

* externalsampler() in tests

* Name Tupple problems

* moving stuff to DynamicPPL RP

* using new DynamicPPL PR

* mistakenly removed line

* specific constructors

* no StaticMH RWMH

* Bump bijectors compat (#2052)

* CompatHelper: bump compat for Bijectors to 0.13, (keep existing compat)

* Update Project.toml

* Replacement for #2039 (#2040)

* Fix testset for external samplers

* Update abstractmcmc.jl

* Update test/contrib/inference/abstractmcmc.jl

Co-authored-by: Tor Erlend Fjelde <tor.erlend95@gmail.com>

* Update test/contrib/inference/abstractmcmc.jl

Co-authored-by: Tor Erlend Fjelde <tor.erlend95@gmail.com>

* Update FillArrays compat to 1.4.1 (#2035)

* Update FillArrays compat to 1.4.0

* Update test compat

* Try to enable ReverseDiff tests

* Update Project.toml

* Update Project.toml

* Bump version

* Revert dependencies on FillArrays (#2042)

* Update Project.toml

* Update Project.toml

* Fix redundant definition of `getstats` (#2044)

* Fix redundant definition of `getstats`

* Update Inference.jl

* Revert "Update Inference.jl"

This reverts commit e4f51c2.

* Bump version

---------

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>

* Transfer some test utility function into DynamicPPL (#2049)

* Update OptimInterface.jl

* Only run optimisation tests in numerical stage.

* fix function lookup after moving functions

---------

Co-authored-by: Xianda Sun <sunxdt@gmail.com>

* Move Optim support to extension (#2051)

* Move Optim support to extension

* More imports

* Update Project.toml

---------

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>

---------

Co-authored-by: CompatHelper Julia <compathelper_noreply@julialang.org>
Co-authored-by: haris organtzidis <organtzh@gmail.com>
Co-authored-by: Tor Erlend Fjelde <tor.erlend95@gmail.com>
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
Co-authored-by: Xianda Sun <sunxdt@gmail.com>
Co-authored-by: Cameron Pfiffer <cpfiffer@gmail.com>

* Bugfixes.

* Add TODO.

* Update mh.jl

* Update Inference.jl

* Removed obsolete exports.

* removed unnecessary import of extract_priors

* added missing ) in MH tests

* fixed incorrect referneces to AdvancedMH in tests

* improve ESLogDensityFunction

* remove hardcoding of SimpleVarInfo

* added fixme comment

* minor style changes

* fixed issues with MH with RandomWalkProposal being used as an external sampler

* fixed accidental typo

* move definitions of unflatten for NamedTuple

* improved TODO

* Update Project.toml

---------

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
Co-authored-by: CompatHelper Julia <compathelper_noreply@julialang.org>
Co-authored-by: haris organtzidis <organtzh@gmail.com>
Co-authored-by: Tor Erlend Fjelde <tor.erlend95@gmail.com>
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
Co-authored-by: Xianda Sun <sunxdt@gmail.com>
Co-authored-by: Cameron Pfiffer <cpfiffer@gmail.com>
Co-authored-by: Hong Ge <hg344@cam.ac.uk>
  • Loading branch information
9 people authored Aug 16, 2023
1 parent d8beaf0 commit 4affc28
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 1 deletion.
1 change: 0 additions & 1 deletion src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ export @model, # modelling
Prior, # Sampling from the prior

MH, # classic sampling
RWMH,
Emcee,
ESS,
Gibbs,
Expand Down
18 changes: 18 additions & 0 deletions src/inference/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,24 @@ Wrap a sampler so it can be used as an inference algorithm.
"""
externalsampler(sampler::AbstractSampler) = ExternalSampler(sampler)

"""
ESLogDensityFunction
A log density function for the External sampler.
"""
const ESLogDensityFunction{M<:Model,S<:Sampler{<:ExternalSampler},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,<:DynamicPPL.DefaultContext}
function LogDensityProblems.logdensity(f::ESLogDensityFunction, x::NamedTuple)
return DynamicPPL.logjoint(f.model, DynamicPPL.unflatten(f.varinfo, x))
end

# TODO: make a nicer `set_namedtuple!` and move these functions to DynamicPPL.
function DynamicPPL.unflatten(vi::TypedVarInfo, θ::NamedTuple)
set_namedtuple!(deepcopy(vi), θ)
return vi
end
DynamicPPL.unflatten(vi::SimpleVarInfo, θ::NamedTuple) = SimpleVarInfo(θ, vi.logp, vi.transformation)

# Algorithm for sampling from the prior
struct Prior <: InferenceAlgorithm end

Expand Down
15 changes: 15 additions & 0 deletions src/inference/mh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,20 @@ function MH(space...)
return MH{tuple(syms...), typeof(proposals)}(proposals)
end

# Some of the proposals require working in unconstrained space.
transform_maybe(proposal::AMH.Proposal) = proposal
function transform_maybe(proposal::AMH.RandomWalkProposal)
return AMH.RandomWalkProposal(Bijectors.transformed(proposal.proposal))
end

function MH(model::Model; proposal_type=AMH.StaticProposal)
priors = DynamicPPL.extract_priors(model)
props = Tuple([proposal_type(prop) for prop in values(priors)])
vars = Tuple(map(Symbol, collect(keys(priors))))
priors = map(transform_maybe, NamedTuple{vars}(props))
return AMH.MetropolisHastings(priors)
end

#####################
# Utility functions #
#####################
Expand Down Expand Up @@ -346,6 +360,7 @@ end
function should_link(varinfo, sampler, proposal::AdvancedMH.RandomWalkProposal)
return true
end
# FIXME: This won't be hit unless `vals` are all the exactly same concrete type of `AdvancedMH.RandomWalkProposal`!
function should_link(
varinfo,
sampler,
Expand Down
6 changes: 6 additions & 0 deletions test/inference/mh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@

s4 = Gibbs(MH(:m), MH(:s))
c4 = sample(gdemo_default, s4, N)

s5 = externalsampler(MH(gdemo_default, proposal_type=AdvancedMH.RandomWalkProposal))
c5 = sample(gdemo_default, s5, N)

s6 = externalsampler(MH(gdemo_default, proposal_type=AdvancedMH.StaticProposal))
c6 = sample(gdemo_default, s6, N)
end
@numerical_testset "mh inference" begin
Random.seed!(125)
Expand Down

2 comments on commit 4affc28

@yebai
Copy link
Member

@yebai yebai commented on 4affc28 Aug 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/89794

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.28.2 -m "<description of version>" 4affc28b341f4763bd1abc8523e4e209e9f6aa6e
git push origin v0.28.2

Please sign in to comment.