Skip to content

Commit

Permalink
fix docs
Browse files Browse the repository at this point in the history
  • Loading branch information
marius311 committed Apr 23, 2024
1 parent d7a7ff6 commit 1c23df3
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 33 deletions.
9 changes: 1 addition & 8 deletions .github/workflows/tests_and_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,14 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
julia-version: ['1.7', '1.8', '1.9', '~1.10.0-0']
julia-version: ['1.7', '1.8', '1.9', '1.10']
threads: ['1', '2']
fail-fast: false
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.julia-version }}
- uses: actions/setup-python@v2
with:
python-version: '3.8'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install matplotlib
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
env:
Expand Down
32 changes: 8 additions & 24 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@ First, load up the packages we'll need:

```@example 1
using MuseInference, Turing
using AbstractDifferentiation, Dates, LinearAlgebra, Printf, PyPlot, Random, Zygote
using AbstractDifferentiation, Dates, LinearAlgebra, Printf, Plots, Random, Zygote
Turing.setadbackend(:zygote)
PyPlot.ioff() # hide
using Logging # hide
Logging.disable_logging(Logging.Info) # hide
Turing.AdvancedVI.PROGRESS[] = false # hide
Expand Down Expand Up @@ -85,7 +84,7 @@ nothing # hide
We next compute the MUSE estimate for the same problem. To reach the same Monte Carlo error as HMC, the number of MUSE simulations should be the same as the effective sample size of the chain we just ran. This is:

```@example 1
nsims = round(Int, ess_rhat(chain)[:θ,:ess])
nsims = round(Int, ess(chain)[:θ,:ess])
```

Running the MUSE estimate,
Expand All @@ -97,29 +96,14 @@ muse_result = @time muse(model, 0; nsims, get_covariance=true)
nothing # hide
```

Lets also try mean-field variational inference (MFVI) to compare to another approximate method.
Now let's plot the different estimates. In this case, MUSE gives a nearly perfect answer in a fraction of the time.

```@example 1
Random.seed!(4)
vi(model, ADVI(10, 10)) # warmup # hide
t_vi = @time @elapsed vi_result = vi(model, ADVI(10, 1000))
nothing # hide
```

Now let's plot the different estimates. In this case, MUSE gives a nearly perfect answer at a fraction of the computational cost. MFVI struggles in both speed and accuracy by comparison.

```@example 1
figure(figsize=(6,5)) # hide
axvline(0, c="k", ls="--", alpha=0.5)
hist(collect(chain["θ"][:]), density=true, bins=15, label=@sprintf("HMC (%.1f seconds)", chain.info.stop_time - chain.info.start_time))
histogram(collect(chain["θ"][:]), normalize=:pdf, bins=10, label=@sprintf("HMC (%.1f seconds)", chain.info.stop_time - chain.info.start_time))
θs = range(-0.5,0.5,length=1000)
plot(θs, pdf.(muse_result.dist, θs), label=@sprintf("MUSE (%.1f seconds)", (muse_result.time / Millisecond(1000))))
plot(θs, pdf.(Normal(vi_result.dist.m[1], vi_result.dist.σ[1]), θs), label=@sprintf("MFVI (%.1f seconds)", t_vi))
legend()
xlabel(L"\theta")
ylabel(L"\mathcal{P}(\theta\,|\,x)")
title("2048-dimensional noisy funnel")
gcf() # hide
plot!(θs, pdf.(muse_result.dist, θs), label=@sprintf("MUSE (%.1f seconds)", (muse_result.time / Millisecond(1000))), lw=2)
vline!([0], c=:black, ls=:dash, alpha=0.5, label=nothing)
plot!(xlabel=L"\theta", ylabel=L"\mathcal{P}(\theta\,|\,x)", title="2048-dimensional noisy funnel")
```

The timing[^1] difference is indicative of the speedups over HMC that are possible. These get even more dramatic as we increase dimensionality, which is why MUSE really excels on high-dimensional problems.
Expand Down Expand Up @@ -180,7 +164,7 @@ prob = SimpleMuseProblem(
function logPrior(θ)
-θ^2/(2*3^2)
end;
autodiff = AD.ZygoteBackend()
autodiff = AbstractDifferentiation.ZygoteBackend()
)
nothing # hide
```
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
MeasureTheory = "eadaa1a4-d27c-401d-8699-e962e1bbc33b"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Soss = "8ce77f84-9b61-11e8-39ff-d17a774bf41c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand Down

0 comments on commit 1c23df3

Please sign in to comment.