From 4293af304f6e31e1105ce4bc368f071174bc8fa1 Mon Sep 17 00:00:00 2001 From: Paulina Martin Date: Mon, 5 Jul 2021 14:43:28 -0500 Subject: [PATCH 1/6] Add violin plots --- src/plot.jl | 47 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/src/plot.jl b/src/plot.jl index 41edb8f2..d6cfbb6f 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -4,12 +4,14 @@ @shorthands pooleddensity @shorthands traceplot @shorthands corner +@shorthands violinplot struct _TracePlot; c; val; end struct _MeanPlot; c; val; end struct _DensityPlot; c; val; end struct _HistogramPlot; c; val; end struct _AutocorPlot; lags; val; end +struct _ViolinPlot; parameters; val; end # define alias functions for old syntax const translationdict = Dict( @@ -18,7 +20,8 @@ const translationdict = Dict( :density => _DensityPlot, :histogram => _HistogramPlot, :autocorplot => _AutocorPlot, - :pooleddensity => _DensityPlot + :pooleddensity => _DensityPlot, + :violinplot => _ViolinPlot ) const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :corner) @@ -184,3 +187,45 @@ end ar = collect(Array(corner.c.value[:, corner.parameters,i]) for i in chains(corner.c)) RecipesBase.recipetype(:cornerplot, vcat(ar...)) end + +@recipe function f( + chains::Chains; + sections = chains.name_map[:parameters], + combined = true +) + + st = get(plotattributes, :seriestype, :traceplot) + if st == :violinplot + if combined + parameters = string.(sections) + val = Array(chains)[:, ] + _ViolinPlot(parameters, val) + + elseif combined == false + data = Array(chains, append_chains = false) + parameters = vec(["param $(sections[i]).Chain $j" + for i in 1:length(sections), + j in 1:length(data)]) + val_vec = vec([data[j][:,i] for i in 1:length(sections), j in 1:length(data)]) + n_iter = length(val_vec[1]) + n_chains = length(val_vec) + val = zeros(Float64, n_iter, n_chains) + for i in 1:n_iter + for j in 1:n_chains + val[i,j] = val_vec[j][i] + end + end + _ViolinPlot(parameters, val[:,]) + else + error("Symbol names are interpreted as parameter names, only compatible with ", + "`colordim = :chain`") + end + end +end + +@recipe function f(p::_ViolinPlot) + seriestype := :violin + xaxis --> "Parameter" + p.parameters, p.val + #[collect(skipmissing(p.val[:,k])) for k in 1:size(p.val)] +end From 4674ea5b2c97eda76e7e39ec723515e385c975c7 Mon Sep 17 00:00:00 2001 From: Paulina Martin Date: Mon, 5 Jul 2021 22:40:36 -0500 Subject: [PATCH 2/6] Modify default size and code refactoring --- src/plot.jl | 49 ++++++++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/src/plot.jl b/src/plot.jl index d6cfbb6f..c95e6798 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -11,7 +11,7 @@ struct _MeanPlot; c; val; end struct _DensityPlot; c; val; end struct _HistogramPlot; c; val; end struct _AutocorPlot; lags; val; end -struct _ViolinPlot; parameters; val; end +struct _ViolinPlot; parameters; val; total_chains; end # define alias functions for old syntax const translationdict = Dict( @@ -195,27 +195,28 @@ end ) st = get(plotattributes, :seriestype, :traceplot) + total_chains = 0 if st == :violinplot if combined parameters = string.(sections) val = Array(chains)[:, ] - _ViolinPlot(parameters, val) - + total_chains = Integer(size(chains.value.data)[3]) + _ViolinPlot(parameters, val, total_chains) elseif combined == false - data = Array(chains, append_chains = false) - parameters = vec(["param $(sections[i]).Chain $j" - for i in 1:length(sections), - j in 1:length(data)]) - val_vec = vec([data[j][:,i] for i in 1:length(sections), j in 1:length(data)]) + chain_arr = Array(chains, append_chains = false) + parameters = ["param $(sections[i]).Chain $j" + for i in 1:length(sections) + for j in 1:length(chain_arr)] + val_vec = [chain_arr[j][:,i] + for i in 1:length(sections) + for j in 1:length(chain_arr)] n_iter = length(val_vec[1]) - n_chains = length(val_vec) - val = zeros(Float64, n_iter, n_chains) - for i in 1:n_iter - for j in 1:n_chains - val[i,j] = val_vec[j][i] - end + total_chains = length(val_vec) + val = zeros(Float64, n_iter, total_chains) + for i in 1:total_chains + val[:,i] = val_vec[:][i] end - _ViolinPlot(parameters, val[:,]) + _ViolinPlot(parameters, val[:,], total_chains) else error("Symbol names are interpreted as parameter names, only compatible with ", "`colordim = :chain`") @@ -224,8 +225,18 @@ end end @recipe function f(p::_ViolinPlot) - seriestype := :violin - xaxis --> "Parameter" - p.parameters, p.val - #[collect(skipmissing(p.val[:,k])) for k in 1:size(p.val)] + @series begin + seriestype := :violin + xaxis --> "Parameter" + size --> (150*p.total_chains, 500) + p.parameters, p.val + end + + @series begin + seriestype := :boxplot + bar_width --> 0.1 + linewidth --> 2 + fillalpha --> 0.8 + p.parameters, p.val + end end From 01d56d52d145e85fbe96444b7b1b96730a46791a Mon Sep 17 00:00:00 2001 From: Paulina Martin Date: Thu, 8 Jul 2021 16:05:57 -0500 Subject: [PATCH 3/6] Correct input variables of _ViolinPlot --- src/plot.jl | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/plot.jl b/src/plot.jl index c95e6798..b9bf96dd 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -190,7 +190,7 @@ end @recipe function f( chains::Chains; - sections = chains.name_map[:parameters], + sections::Vector{Symbol} = chains.name_map[:parameters], combined = true ) @@ -198,17 +198,16 @@ end total_chains = 0 if st == :violinplot if combined - parameters = string.(sections) - val = Array(chains)[:, ] + n_iter, n_parameters = size(Array(chains)) + parameters = string.(repeat(sections, inner = n_iter)) + val = vec(Array(chains)) total_chains = Integer(size(chains.value.data)[3]) _ViolinPlot(parameters, val, total_chains) elseif combined == false + n_parameters = length(sections) chain_arr = Array(chains, append_chains = false) - parameters = ["param $(sections[i]).Chain $j" - for i in 1:length(sections) - for j in 1:length(chain_arr)] val_vec = [chain_arr[j][:,i] - for i in 1:length(sections) + for i in 1:n_parameters for j in 1:length(chain_arr)] n_iter = length(val_vec[1]) total_chains = length(val_vec) @@ -216,7 +215,12 @@ end for i in 1:total_chains val[:,i] = val_vec[:][i] end - _ViolinPlot(parameters, val[:,], total_chains) + val = vec(val) + parameters_names = ["param $(sections[i]).Chain $j" + for i in 1:n_parameters + for j in 1:length(chain_arr)] + parameters = string.(repeat(parameters_names, inner = n_iter)) + _ViolinPlot(parameters, val, total_chains) else error("Symbol names are interpreted as parameter names, only compatible with ", "`colordim = :chain`") @@ -228,7 +232,7 @@ end @series begin seriestype := :violin xaxis --> "Parameter" - size --> (150*p.total_chains, 500) + size --> (200*p.total_chains, 500) p.parameters, p.val end From e4c9765f32fef64be15bf7b8051954b525caf5f7 Mon Sep 17 00:00:00 2001 From: Paulina Martin Date: Mon, 12 Jul 2021 20:35:26 -0500 Subject: [PATCH 4/6] Fix warning message --- src/plot.jl | 85 ++++++++++++++++++++++++----------------------------- 1 file changed, 39 insertions(+), 46 deletions(-) diff --git a/src/plot.jl b/src/plot.jl index b9bf96dd..4d3a44c9 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -11,7 +11,7 @@ struct _MeanPlot; c; val; end struct _DensityPlot; c; val; end struct _HistogramPlot; c; val; end struct _AutocorPlot; lags; val; end -struct _ViolinPlot; parameters; val; total_chains; end +struct _ViolinPlot; par; val; end # define alias functions for old syntax const translationdict = Dict( @@ -33,7 +33,9 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor colordim = :chain, barbounds = (-Inf, Inf), maxlag = nothing, - append_chains = false + append_chains = false, + sections = chains.name_map[:parameters], + combined = true ) st = get(plotattributes, :seriestype, :traceplot) c = append_chains || st == :pooleddensity ? pool_chain(chains) : chains @@ -72,6 +74,39 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor else range(c), val end + + total_chains = i + if st == :violinplot + n_iter, n_par, n_chains = size(chains) + if combined + colordim := :chain + par = string.(reshape(repeat(sections, inner = n_iter), n_iter, n_par))[:,i] + val = Array(chains)[:,i] + _ViolinPlot(par, val) + elseif combined == false + if colordim == :chain + par_names = ["$(sections[i]).Chain $j" for i in 1:n_par, j in 1:n_chains] + pars = string.(reshape(repeat(vec(par_names), inner = n_iter), (n_iter, n_par, n_chains))) + val = chains.value[:,i,:] + par = pars[:,i,:] + elseif colordim == :parameter + par_vec = repeat(sections, inner = n_iter) + pars = string.(reshape(repeat(par_vec, n_chains, 1), (n_iter, n_par, n_chains))) + val = chains.value[:,:,i] + par = pars[:,:,i] + label --> string.(names(c)) + else + throw(ArgumentError("`colordim` must be one of `:chain` or `:parameter`")) + end + _ViolinPlot(par, val) + else + throw(ArgumentError("In `ViolinPlots` `Chains` can be combined or separated ")) + end + elseif st ∈ supportedplots + translationdict[st](c, val) + else + range(c), val + end end @recipe function f(p::_DensityPlot) @@ -188,52 +223,10 @@ end RecipesBase.recipetype(:cornerplot, vcat(ar...)) end -@recipe function f( - chains::Chains; - sections::Vector{Symbol} = chains.name_map[:parameters], - combined = true -) - - st = get(plotattributes, :seriestype, :traceplot) - total_chains = 0 - if st == :violinplot - if combined - n_iter, n_parameters = size(Array(chains)) - parameters = string.(repeat(sections, inner = n_iter)) - val = vec(Array(chains)) - total_chains = Integer(size(chains.value.data)[3]) - _ViolinPlot(parameters, val, total_chains) - elseif combined == false - n_parameters = length(sections) - chain_arr = Array(chains, append_chains = false) - val_vec = [chain_arr[j][:,i] - for i in 1:n_parameters - for j in 1:length(chain_arr)] - n_iter = length(val_vec[1]) - total_chains = length(val_vec) - val = zeros(Float64, n_iter, total_chains) - for i in 1:total_chains - val[:,i] = val_vec[:][i] - end - val = vec(val) - parameters_names = ["param $(sections[i]).Chain $j" - for i in 1:n_parameters - for j in 1:length(chain_arr)] - parameters = string.(repeat(parameters_names, inner = n_iter)) - _ViolinPlot(parameters, val, total_chains) - else - error("Symbol names are interpreted as parameter names, only compatible with ", - "`colordim = :chain`") - end - end -end - @recipe function f(p::_ViolinPlot) @series begin seriestype := :violin - xaxis --> "Parameter" - size --> (200*p.total_chains, 500) - p.parameters, p.val + p.par, p.val end @series begin @@ -241,6 +234,6 @@ end bar_width --> 0.1 linewidth --> 2 fillalpha --> 0.8 - p.parameters, p.val + p.par, p.val end end From 1eb9c58a0b2ce527c56fe199b16f76063473823b Mon Sep 17 00:00:00 2001 From: Paulina Martin Date: Tue, 13 Jul 2021 15:03:57 -0500 Subject: [PATCH 5/6] Bump minor version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2efb10e4..56a4b19a 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "Chain types and utility functions for MCMC simulations." -version = "4.12.0" +version = "4.15.0" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" From 1bb0f93add25d92d772f69c739e1cf298e6b3453 Mon Sep 17 00:00:00 2001 From: Paulina Martin Date: Wed, 28 Jul 2021 21:59:49 -0500 Subject: [PATCH 6/6] Fix error in tests and add a test for violinplot --- src/plot.jl | 16 +++++----------- test/plot_test.jl | 17 ++++++++++------- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/src/plot.jl b/src/plot.jl index 4d3a44c9..fc1e3232 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -34,7 +34,7 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor barbounds = (-Inf, Inf), maxlag = nothing, append_chains = false, - sections = chains.name_map[:parameters], + par_sections = chains.name_map[:parameters], combined = true ) st = get(plotattributes, :seriestype, :traceplot) @@ -69,28 +69,22 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor ac_mat = convert(Array, ac) val = colordim == :parameter ? ac_mat[:, :, i]' : ac_mat[i, :, :] _AutocorPlot(lags, val) - elseif st ∈ supportedplots - translationdict[st](c, val) - else - range(c), val - end - total_chains = i - if st == :violinplot + elseif st == :violinplot n_iter, n_par, n_chains = size(chains) if combined colordim := :chain - par = string.(reshape(repeat(sections, inner = n_iter), n_iter, n_par))[:,i] + par = string.(reshape(repeat(par_sections, inner = n_iter), n_iter, n_par))[:,i] val = Array(chains)[:,i] _ViolinPlot(par, val) elseif combined == false if colordim == :chain - par_names = ["$(sections[i]).Chain $j" for i in 1:n_par, j in 1:n_chains] + par_names = ["$(par_sections[i]).Chain $j" for i in 1:n_par, j in 1:n_chains] pars = string.(reshape(repeat(vec(par_names), inner = n_iter), (n_iter, n_par, n_chains))) val = chains.value[:,i,:] par = pars[:,i,:] elseif colordim == :parameter - par_vec = repeat(sections, inner = n_iter) + par_vec = repeat(par_sections, inner = n_iter) pars = string.(reshape(repeat(par_vec, n_chains, 1), (n_iter, n_par, n_chains))) val = chains.value[:,:,i] par = pars[:,:,i] diff --git a/test/plot_test.jl b/test/plot_test.jl index 9654ed8c..bb142e5d 100644 --- a/test/plot_test.jl +++ b/test/plot_test.jl @@ -24,29 +24,32 @@ Logging.disable_logging(Logging.Warn) println("traceplot") display(traceplot(chn, 1)) println() - + println("meanplot") display(meanplot(chn, 1)) println() - + println("density") display(density(chn, 1)) display(density(chn, 1, append_chains=true)) println() - + println("autocorplot") display(autocorplot(chn, 1)) println() - + #ps_contour = plot(chn, :contour) println("histogram") display(histogram(chn, 1)) println() - + println("\nmixeddensity") display(mixeddensity(chn, 1)) - + + println("violinplot") + display(violinplot(chn)) + println() # plotting combinations display(plot(chn)) display(plot(chn, append_chains=true)) @@ -54,7 +57,7 @@ Logging.disable_logging(Logging.Warn) # Test plotting using colordim keyword display(plot(chn, colordim = :parameter)) - + # Test if plotting a sub-set work.s display(plot(chn, 2)) display(plot(chn, 2, colordim = :parameter))