Skip to content

Commit

Permalink
Reduce iterations in HMC tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mhauru committed Dec 6, 2024
1 parent a8b5cdd commit 320bbd0
Showing 1 changed file with 33 additions and 37 deletions.
70 changes: 33 additions & 37 deletions test/mcmc/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ using Test: @test, @test_logs, @testset, @test_throws
using Turing

@testset "Testing hmc.jl with $adbackend" for adbackend in ADUtils.adbackends
# Set a seed
@info "Starting HMC tests with $adbackend"
rng = StableRNG(123)

@testset "constrained bounded" begin
obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1]

Expand All @@ -36,11 +37,12 @@ using Turing
rng,
constrained_test(obs),
HMC(1.5, 3; adtype=adbackend),# using a large step size (1.5)
1000,
500,
)

check_numerical(chain, [:p], [10 / 14]; atol=0.1)
end

@testset "constrained simplex" begin
obs12 = [1, 2, 1, 2, 2, 2, 2, 2, 2, 2]

Expand All @@ -54,23 +56,25 @@ using Turing
end

chain = sample(
rng, constrained_simplex_test(obs12), HMC(0.75, 2; adtype=adbackend), 1000
rng, constrained_simplex_test(obs12), HMC(0.75, 2; adtype=adbackend), 100
)

check_numerical(chain, ["ps[1]", "ps[2]"], [5 / 16, 11 / 16]; atol=0.015)
end

@testset "hmc reverse diff" begin
alg = HMC(0.1, 10; adtype=adbackend)
res = sample(rng, gdemo_default, alg, 4000)
res = sample(rng, gdemo_default, alg, 2000)
check_gdemo(res; rtol=0.1)
end

@testset "matrix support" begin
@model function hmcmatrixsup()
return v ~ Wishart(7, [1 0.5; 0.5 1])
end

model_f = hmcmatrixsup()
n_samples = 1_000
n_samples = 100
vs = map(1:3) do _
chain = sample(rng, model_f, HMC(0.15, 7; adtype=adbackend), n_samples)
r = reshape(Array(group(chain, :v)), n_samples, 2, 2)
Expand All @@ -79,6 +83,7 @@ using Turing

@test maximum(abs, mean(vs) - (7 * [1 0.5; 0.5 1])) <= 0.5
end

@testset "multivariate support" begin
# Define NN flow
function nn(x, b1, w11, w12, w13, bo, wo)
Expand Down Expand Up @@ -129,53 +134,42 @@ using Turing

@testset "hmcda inference" begin
alg1 = HMCDA(500, 0.8, 0.015; adtype=adbackend)
# alg2 = Gibbs(HMCDA(200, 0.8, 0.35, :m; adtype=adbackend), HMC(0.25, 3, :s; adtype=adbackend))

# alg3 = Gibbs(HMC(0.25, 3, :m; adtype=adbackend), PG(30, 3, :s))
# alg3 = PG(50, 2000)

res1 = sample(rng, gdemo_default, alg1, 3000)
check_gdemo(res1)

# res2 = sample(gdemo([1.5, 2.0]), alg2)
#
# @test mean(res2[:s]) ≈ 49/24 atol=0.2
# @test mean(res2[:m]) ≈ 7/6 atol=0.2
end

# TODO(mhauru) The below one is a) slow, b) flaky, in that changing the seed can
# easily make it fail, despite many more samples taken by most other tests.
@testset "hmcda+gibbs inference" begin
rng = StableRNG(123)
Random.seed!(12345) # particle samplers do not support user-provided `rng` yet
alg3 = Gibbs(PG(20, :s), HMCDA(500, 0.8, 0.25, :m; init_ϵ=0.05, adtype=adbackend))

res3 = sample(rng, gdemo_default, alg3, 3000; discard_initial=1000)
check_gdemo(res3)
Random.seed!(12345)
alg = Gibbs(PG(20, :s), HMCDA(500, 0.8, 0.25, :m; init_ϵ=0.05, adtype=adbackend))
res = sample(rng, gdemo_default, alg, 3000; discard_initial=1000)
check_gdemo(res)
end

@testset "hmcda constructor" begin
alg = HMCDA(0.8, 0.75; adtype=adbackend)
println(alg)
sampler = Sampler(alg, gdemo_default)
@test DynamicPPL.alg_str(sampler) == "HMCDA"

alg = HMCDA(200, 0.8, 0.75; adtype=adbackend)
println(alg)
sampler = Sampler(alg, gdemo_default)
@test DynamicPPL.alg_str(sampler) == "HMCDA"

alg = HMCDA(200, 0.8, 0.75, :s; adtype=adbackend)
println(alg)
sampler = Sampler(alg, gdemo_default)
@test DynamicPPL.alg_str(sampler) == "HMCDA"

@test isa(alg, HMCDA)
@test isa(sampler, Sampler{<:Turing.Hamiltonian})
end

@testset "nuts inference" begin
alg = NUTS(1000, 0.8; adtype=adbackend)
res = sample(rng, gdemo_default, alg, 6000)
res = sample(rng, gdemo_default, alg, 500)
check_gdemo(res)
end

@testset "nuts constructor" begin
alg = NUTS(200, 0.65; adtype=adbackend)
sampler = Sampler(alg, gdemo_default)
Expand All @@ -189,6 +183,7 @@ using Turing
sampler = Sampler(alg, gdemo_default)
@test DynamicPPL.alg_str(sampler) == "NUTS"
end

@testset "check discard" begin
alg = NUTS(100, 0.8; adtype=adbackend)

Expand All @@ -198,13 +193,14 @@ using Turing
@test size(c1, 1) == 500
@test size(c2, 1) == 500
end

@testset "AHMC resize" begin
alg1 = Gibbs(PG(10, :m), NUTS(100, 0.65, :s; adtype=adbackend))
alg2 = Gibbs(PG(10, :m), HMC(0.1, 3, :s; adtype=adbackend))
alg3 = Gibbs(PG(10, :m), HMCDA(100, 0.65, 0.3, :s; adtype=adbackend))
@test sample(rng, gdemo_default, alg1, 300) isa Chains
@test sample(rng, gdemo_default, alg2, 300) isa Chains
@test sample(rng, gdemo_default, alg3, 300) isa Chains
@test sample(rng, gdemo_default, alg1, 10) isa Chains
@test sample(rng, gdemo_default, alg2, 10) isa Chains
@test sample(rng, gdemo_default, alg3, 10) isa Chains
end

@testset "Regression tests" begin
Expand All @@ -213,28 +209,28 @@ using Turing
m = Matrix{T}(undef, 2, 3)
return m .~ MvNormal(zeros(2), I)
end
@test sample(rng, mwe1(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains
@test sample(rng, mwe1(), HMC(0.2, 4; adtype=adbackend), 100) isa Chains

@model function mwe2(::Type{T}=Matrix{Float64}) where {T}
m = T(undef, 2, 3)
return m .~ MvNormal(zeros(2), I)
end
@test sample(rng, mwe2(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains
@test sample(rng, mwe2(), HMC(0.2, 4; adtype=adbackend), 100) isa Chains

# https://github.com/TuringLang/Turing.jl/issues/1308
@model function mwe3(::Type{T}=Array{Float64}) where {T}
m = T(undef, 2, 3)
return m .~ MvNormal(zeros(2), I)
end
@test sample(rng, mwe3(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains
@test sample(rng, mwe3(), HMC(0.2, 4; adtype=adbackend), 100) isa Chains
end

# issue #1923
@testset "reproducibility" begin
alg = NUTS(1000, 0.8; adtype=adbackend)
res1 = sample(StableRNG(123), gdemo_default, alg, 1000)
res2 = sample(StableRNG(123), gdemo_default, alg, 1000)
res3 = sample(StableRNG(123), gdemo_default, alg, 1000)
res1 = sample(rng, gdemo_default, alg, 10)
res2 = sample(rng, gdemo_default, alg, 10)
res3 = sample(rng, gdemo_default, alg, 10)
@test Array(res1) == Array(res2) == Array(res3)
end

Expand All @@ -249,7 +245,7 @@ using Turing
gdemo_default_prior = DynamicPPL.contextualize(
demo_hmc_prior(), DynamicPPL.PriorContext()
)
chain = sample(gdemo_default_prior, alg, 10_000; initial_params=[3.0, 0.0])
chain = sample(gdemo_default_prior, alg, 500; initial_params=[3.0, 0.0])
check_numerical(
chain, [:s, :m], [mean(truncated(Normal(3, 1); lower=0)), 0]; atol=0.2
)
Expand Down Expand Up @@ -288,7 +284,7 @@ using Turing
return xs[2] ~ Dirichlet(ones(5))
end
model = vector_of_dirichlet()
chain = sample(model, NUTS(), 1000)
chain = sample(model, NUTS(), 100)
@test mean(Array(chain)) 0.2
end

Expand All @@ -311,7 +307,7 @@ using Turing
end

model = buggy_model()
num_samples = 1_000
num_samples = 100

chain = sample(model, NUTS(), num_samples; initial_params=[0.5, 1.75, 1.0])
chain_prior = sample(model, Prior(), num_samples)
Expand Down

0 comments on commit 320bbd0

Please sign in to comment.