
Supporting more MCMCChains variable names #211

Open
sethaxen opened this issue Aug 8, 2022 · 9 comments
Labels
discussion enhancement New feature or request

Comments

@sethaxen
Member

sethaxen commented Aug 8, 2022

Currently from_mcmcchains assumes that all variable names are single-bracket-delimited or dot-delimited. However, quite complicated names are possible:

julia> using Turing, LinearAlgebra

julia> @model function foo()
           a ~ Normal()
           bar = (b=Matrix{typeof(a)}(undef, 2, 3), c=Vector{typeof(a)}(undef, 3))
           bar.b[1, :] .~ Normal(a)
           bar.b[2, 1:2] ~ MvNormal(I(2))
           bar.b[2, 3:end] .~ Normal(a)
           bar.c[end:-1:1] .~ Normal(a)
       end
foo (generic function with 2 methods)

julia> chns = sample(foo(), NUTS(), 1_000)
┌ Info: Found initial step size
└   ϵ = 1.6
Sampling 100%|█████████████████████████████████████████████████████████████████████████| Time: 0:00:01
Chains MCMC chain (1000×22×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 1.21 seconds
Compute duration  = 1.21 seconds
parameters        = a, bar.b[1,:][1], bar.b[1,:][2], bar.b[1,:][3], bar.b[2,1:2][1], bar.b[2,1:2][2], bar.b[2,3:3][1], bar.c[3:-1:1][1], bar.c[3:-1:1][2], bar.c[3:-1:1][3]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
        parameters      mean       std   naive_se      mcse         ess      rhat   ess_per_sec 
            Symbol   Float64   Float64    Float64   Float64     Float64   Float64       Float64 

                 a    0.0278    0.9966     0.0315    0.0724    139.5528    1.0152      115.1426
     bar.b[1,:][1]    0.0178    1.4143     0.0447    0.0841    205.0380    1.0096      169.1733
     bar.b[1,:][2]    0.0268    1.3678     0.0433    0.0765    292.3916    1.0072      241.2472
     bar.b[1,:][3]    0.0395    1.3866     0.0438    0.0852    234.2846    1.0061      193.3041
   bar.b[2,1:2][1]    0.0395    1.0091     0.0319    0.0260   1017.5267    0.9990      839.5435
   bar.b[2,1:2][2]   -0.0270    0.9839     0.0311    0.0260   1495.1411    0.9990     1233.6147
   bar.b[2,3:3][1]    0.0506    1.3798     0.0436    0.0724    347.4877    1.0032      286.7061
  bar.c[3:-1:1][1]    0.0156    1.4357     0.0454    0.0835    275.8889    1.0079      227.6311
  bar.c[3:-1:1][2]   -0.0010    1.4186     0.0449    0.0720    362.6178    1.0072      299.1896
  bar.c[3:-1:1][3]    0.0146    1.4061     0.0445    0.0783    333.0312    1.0069      274.7783

Quantiles
        parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
            Symbol   Float64   Float64   Float64   Float64   Float64 

                 a   -1.8465   -0.6502   -0.0164    0.7037    2.0664
     bar.b[1,:][1]   -2.7699   -0.9711    0.0028    0.9771    2.7813
     bar.b[1,:][2]   -2.5305   -0.9451    0.0408    0.9232    2.7366
     bar.b[1,:][3]   -2.6779   -0.9212    0.0311    1.0146    2.7213
   bar.b[2,1:2][1]   -1.8999   -0.6486    0.0397    0.7054    2.1453
   bar.b[2,1:2][2]   -1.9572   -0.7017   -0.0031    0.6454    1.8877
   bar.b[2,3:3][1]   -2.6416   -0.8997   -0.0014    0.9453    3.0213
  bar.c[3:-1:1][1]   -2.8470   -0.9299   -0.0128    0.9509    2.8362
  bar.c[3:-1:1][2]   -2.8268   -0.9619    0.0671    0.9812    2.7441
  bar.c[3:-1:1][3]   -2.7313   -0.9267    0.0001    0.9474    2.8913

If we call from_mcmcchains on this, we get an uninformative error.

Ideally we would like to get an InferenceData with bar.b and bar.c as variables. However, since the modeler can arbitrarily index and reindex and call getproperty to make arbitrarily complicated types, always doing the right thing is probably not possible. Also, MCMCChains's own machinery for combining flattened parameters into parameter arrays doesn't do a great job here.

For the short term, then, I think it makes the most sense to raise an informative error if splitting by brackets produces anything more complicated than a tuple of integer indices. If users find this constraining and open issues, we can discuss supporting slightly more complicated indexing syntaxes.
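A minimal sketch of the proposed check follows. The helper split_indices is hypothetical (it is not part of ArviZ.jl). For simplicity, it assumes that any "." belongs to the variable name rather than acting as an index separator. A name with no brackets is treated as a scalar. A name ending in one bracketed list of integers is split into a base name and an index tuple. Anything else raises an ArgumentError that says which part could not be parsed.

```julia
# Hypothetical helper (not part of ArviZ.jl): split a flattened Chains
# parameter name like "x[1,2]" into ("x", (1, 2)).
# Assumption: '.' is part of the variable name, never an index separator.
function split_indices(name::AbstractString)
    # Exactly one bracketed group, at the end of the name.
    m = match(r"^([^\[\]]+)\[([^\[\]]*)\]$", name)
    if m === nothing
        # No brackets at all: treat the name as a scalar variable.
        occursin(r"[\[\]]", name) || return (String(name), ())
        # Brackets present but not in the simple form: fail informatively.
        throw(ArgumentError(
            "Cannot parse parameter name \"$name\": only a single bracketed " *
            "list of integer indices, e.g. \"x[1,2]\", is supported."))
    end
    base = String(m.captures[1])
    parts = strip.(split(m.captures[2], ','))
    # Every index must be an integer; slices like ":" or "1:2" are rejected.
    idx = map(parts) do p
        i = tryparse(Int, p)
        i === nothing && throw(ArgumentError(
            "Cannot parse parameter name \"$name\": index \"$p\" is not an integer."))
        i
    end
    return (base, Tuple(idx))
end
```

With this sketch, split_indices("x[1,2]") returns ("x", (1, 2)). A name from the example above such as "bar.b[1,:][1]" raises an ArgumentError instead of producing an uninformative failure further downstream.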

sethaxen added the enhancement and discussion labels Aug 8, 2022
@SamuelBrand1

Was there ever a resolution on this? I think I'm hitting an issue with a project I'm working on.

@sethaxen
Member Author

No, we haven't implemented a solution for this (mostly because no one else reported this being an issue). But now that LKJCholesky draws will appear in Chains with a syntax like F.L[i, j], which I think would hit this issue, it's a bit higher priority.

The long-term goal was to make InferenceData a chaintype that could be returned by sample instead of Chains, so we could maybe avoid Chains's flattening altogether (see TuringLang/DynamicPPL.jl#464). But that's a much bigger project that I haven't had the focused time to finish.

Can you share what your Chains object looks like? I might be able to come up with a workaround or a quick patch for that case.

@SamuelBrand1

Hey @sethaxen

So a slice of the summary from a quick example run gives names like this (yes, I know the convergence is horrible here ;-) ):

Summary Statistics
             parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ess_per_sec 
                 Symbol   Float64   Float64   Float64     Float64     Float64   Float64       Float64 

  latent.latent_init[1]   -0.0390    0.0928    0.0246     14.7598     65.6111    1.2054        0.0129
            latent.σ_AR    0.4779    0.2497    0.1204      4.8553     11.2507    2.5235        0.0042
      latent.ar_init[1]    0.0588    0.1163    0.0403      8.8253     15.4617    1.3860        0.0077
      latent.damp_AR[1]    0.4611    0.1653    0.0787      4.8405     11.6086    2.5636        0.0042
          latent.ϵ_t[1]   -0.0325    0.2224    0.0860      6.1744     27.2663    1.7365        0.0054
          latent.ϵ_t[2]   -0.6047    0.6202    0.2991      4.6042     11.6951    3.0378        0.0040
         init_incidence    4.6052    0.0000    0.0000   2131.9378   1288.5933    1.0015        1.8563
                obs.std    0.2493    0.3386    0.1667      4.8690     11.5049    2.4899        0.0042
             obs.ϵ_t[1]    0.1988    0.4118    0.1979      5.6882     15.6413    1.9164        0.0050

@sethaxen
Member Author

Okay, I think we can support this by removing support for the old . syntax for separating indices. Only Stan sample/CmdStan ever used that, and now they have native InferenceData support. I'll put together a PR.

@SamuelBrand1

Wow! Quick work. Thanks.

@sethaxen
Member Author

@SamuelBrand1, with the just-released ArviZ v0.11.0, you shouldn't have any issues with your example model, but let me know if you run into any. For the example at the top of this issue, you would now get the variable names a, bar.b[1, :], bar.b[2, 1:2], etc. instead of an error.

Note that many of the fancy indexing cases in #211 (comment) will no longer error, but they will also not be processed into multi-dimensional arrays.
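The sketch below illustrates the behavior described above. It is not ArviZ.jl's actual implementation; the function group_names and the alias IndexTuple are hypothetical. It assumes, as the variable names quoted above suggest, that "." is no longer treated as an index separator and that only the final bracketed group is read as integer indices. Everything before that final group becomes the variable name.

```julia
# Hypothetical sketch (not ArviZ.jl's actual code): group flattened Chains
# parameter names into variables. '.' is treated as part of the variable
# name, and only the final bracketed group is read as integer indices, so
# "bar.b[1,:][2]" becomes index (2,) of a variable named "bar.b[1,:]".
const IndexTuple = Tuple{Vararg{Int}}

function group_names(names::AbstractVector{<:AbstractString})
    groups = Dict{String,Vector{IndexTuple}}()
    for name in names
        # Greedy (.+) makes the first capture end at the *last* '[' in the name.
        m = match(r"^(.+)\[([^\[\]]+)\]$", name)
        idx = m === nothing ? nothing : tryparse.(Int, strip.(split(m.captures[2], ',')))
        if idx === nothing || any(isnothing, idx)
            # No trailing integer index: keep the full name as a scalar variable.
            push!(get!(groups, String(name), IndexTuple[]), ())
        else
            # Collect this element's indices under its base variable name.
            push!(get!(groups, String(m.captures[1]), IndexTuple[]), Tuple(idx))
        end
    end
    return groups
end
```

Under this rule, latent.ϵ_t[1] and latent.ϵ_t[2] are grouped under latent.ϵ_t, and an LKJCholesky-style name such as F.L[1,2] is grouped under F.L. By contrast, bar.b[1,:][1] becomes an element of a variable named bar.b[1,:] rather than part of a 2×3 matrix bar.b.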

@SamuelBrand1

Thanks so much, I'll give this a whirl now

@SamuelBrand1

It works for our use case! Thanks so much for the speedy work here.

@sethaxen
Member Author

No problem!
