Skip to content

Commit

Permalink
Merge pull request #271 from Julia-Tempering/fix-inv-test
Browse files Browse the repository at this point in the history
Invariance test: switch to 2 indep samples instead of paired design
  • Loading branch information
alexandrebouchard authored Aug 15, 2024
2 parents a0e764f + abf6489 commit 40cdeab
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 12 deletions.
12 changes: 7 additions & 5 deletions ext/PigeonsDynamicPPLExt/invariance_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ function Pigeons.forward_sample_condition_and_explore(
model::DynamicPPL.Model,
explorer,
rng::SplittableRandom;
run_explorer::Bool = true,
condition_on::NTuple{N,Symbol}
) where {N}
# forward simulation
Expand All @@ -35,12 +36,13 @@ function Pigeons.forward_sample_condition_and_explore(
state = DynamicPPL.TypedVarInfo(cond_vi)
DynamicPPL.link!!(state, DynamicPPL.SampleFromPrior(), conditioned_model)

# record starting values and then take a step with explorer
init_values = DynamicPPL.getall(state)
final_state = Pigeons.explorer_step(rng, TuringLogPotential(conditioned_model), explorer, state)
# maybe take a step with explorer
if run_explorer
state = Pigeons.explorer_step(rng, TuringLogPotential(conditioned_model), explorer, state)
end

# return initial and final values
return (;init_values=init_values, final_values=DynamicPPL.getall(final_state))
# return a flattened version of state
return DynamicPPL.getall(state)
end

Pigeons.forward_sample_condition_and_explore(target::TuringLogPotential, args...; kwargs...) =
Expand Down
8 changes: 5 additions & 3 deletions ext/PigeonsHypothesisTestsExt/PigeonsHypothesisTestsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ function Pigeons.invariance_test(

# iterate iid samples
for n in eachindex(initial_values)
inits, finals = Pigeons.forward_sample_condition_and_explore(target, explorer, rng; simulator_kwargs...)
initial_values[n] = inits
final_values[n] = finals
initial_values[n] = Pigeons.forward_sample_condition_and_explore(
target, explorer, rng; run_explorer=false, simulator_kwargs...)
final_values[n] = Pigeons.forward_sample_condition_and_explore(
target, explorer, rng; simulator_kwargs...)
end

# transform vector of vectors to matrices so that iterating dimensions == iterating columns => faster
Expand Down Expand Up @@ -64,4 +65,5 @@ function Pigeons.invariance_test(
return (;passed=passed, pvalues=pvalues, failed_tests=failed_tests)
end


end # End module
30 changes: 28 additions & 2 deletions src/explorers/invariance_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ The workhorse under [`invariance_test`](@ref). It starts with a full forward pas
for the probabilistic model underlying `target`, thats simulates latent variables and
observations. Then a modified model is created that conditions the original model
on the observations produced. Finally, the function takes a step using the explorer
targetting the conditioned model. The function returns both pre- and post-exploration
states.
targetting the conditioned model and the final state is returned. The exploration
can be optionally disabled by passing `run_explorer=false`, in which case the
initial simulated state is returned.
"""
function forward_sample_condition_and_explore end

Expand All @@ -49,3 +50,28 @@ function explorer_step(rng::SplittableRandom, target, explorer, init_state)
Pigeons.step!(explorer, replica, shared)
return replica.state
end


#=
Implementations of forward_sample_condition_and_explore for Pigeons' toy targets
that allow forward simulation
=#

"""
$SIGNATURES
Implementation for [`ScaledPrecisionNormalPath`](@ref). Since this toy model
allows direct iid sampling from the target, conditioning is not necessary.
"""
function forward_sample_condition_and_explore(
target::ScaledPrecisionNormalPath,
explorer,
rng::SplittableRandom;
run_explorer::Bool = true
)
state = initialization(target, rng, 1) # forward simulation
if run_explorer
state = explorer_step(rng, target, explorer, state)
end
return state
end
2 changes: 1 addition & 1 deletion src/includes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ include("explorers/Mix.jl")
include("explorers/Preconditioner.jl")
include("explorers/MALA.jl")
include("explorers/AutoMALA.jl")
include("explorers/invariance_test.jl")
include("explorers/Compose.jl")
include("explorers/Augmentation.jl")
include("targets/DistributionLogPotential.jl")
Expand All @@ -82,6 +81,7 @@ include("explorers/BufferedAD.jl")
include("variational/GaussianReference.jl")
include("variational/VariationalReference.jl")
include("paths/ScaledPrecisionNormalPath.jl")
include("explorers/invariance_test.jl")
include("targets/toy_mvn_target.jl")
include("explorers/AAPS.jl")
include("explorers/GradientBasedSampler.jl")
Expand Down
2 changes: 1 addition & 1 deletion test/test_invariance_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ using HypothesisTests
Pigeons.invariance_test(target, IdentityExplorer(), rng; condition_on=(:n_successes,))
end
@test res.passed
@test all(==(1), res.pvalues)
end

@testset "Test a true positive" begin
Expand All @@ -49,6 +48,7 @@ using HypothesisTests
for explorer in explorers
@show explorer
@test first(Pigeons.invariance_test(target, explorer, rng; condition_on=(:n_successes,)))
@test first(Pigeons.invariance_test(toy_mvn_target(2), explorer, rng))
end
end
end

0 comments on commit 40cdeab

Please sign in to comment.