
Use coordinate names in posterior summary() #1091

Closed
ianmtaylor1 opened this issue Feb 26, 2020 · 2 comments · Fixed by #1201

@ianmtaylor1

Plotting functions such as plot_posterior() and plot_trace() identify the elements of multidimensional variables by their coordinate names, but summary() identifies them by integer position (beta[0], beta[1], ...). Using the coordinate names in summary() would make its output clearer.

Code to demonstrate:

import pystan
import arviz
import numpy

print(arviz.__version__)

code = """
data {
    int<lower=1> N;
    int<lower=1> P;
    vector[N] y;
    matrix[N,P] X;
}
parameters {
    vector[P] beta;
    real alpha;
    real<lower=0> sigma;
}
model {
    beta ~ std_normal();
    alpha ~ std_normal();
    sigma ~ cauchy(0, 1) T[0,];
    y ~ normal_id_glm(X, alpha, beta, sigma);
}
"""

# Compile model
model = pystan.StanModel(model_code=code)

#### Data
# Dimensions
P = 4
N = 200
# Parameters
alpha = 1
beta = numpy.array(range(P)) - (P - 1)/2
sigma = 2
# Data
rg = numpy.random.default_rng()
X = rg.normal(size=(N,P))
y = numpy.matmul(X, beta) + alpha + rg.normal(scale=sigma, size=N)

# Sample
fit = model.sampling(
    data={'N':N, 'P':P, 'X':X, 'y':y},
    init='random', n_jobs=1, 
    pars=['alpha','beta','sigma'])

samples = arviz.from_pystan(fit,
                            observed_data='y',
                            constant_data='X',
                            coords={'observation':list(range(N)),
                                    'covariate':['A','B','C','D']},
                            dims={'y':['observation'], 
                                  'beta':['covariate'], 
                                  'X':['observation','covariate']})

arviz.plot_posterior(samples)
arviz.plot_trace(samples)
print(arviz.summary(samples))

Output

0.6.1
(Figure_1: plot_posterior output; Figure_2: plot_trace output)

          mean     sd  hpd_3%  hpd_97%  mcse_mean  mcse_sd  ess_mean  ess_sd  ess_bulk  ess_tail  r_hat
alpha    1.026  0.144   0.752    1.301      0.002    0.001    5294.0  5294.0    5258.0    2582.0    1.0
beta[0] -1.504  0.148  -1.800   -1.239      0.002    0.001    5550.0  5293.0    5570.0    3159.0    1.0
beta[1] -0.531  0.153  -0.825   -0.252      0.002    0.002    5597.0  5040.0    5582.0    3062.0    1.0
beta[2]  0.396  0.140   0.143    0.671      0.002    0.001    5700.0  5498.0    5682.0    3168.0    1.0
beta[3]  1.086  0.155   0.801    1.390      0.002    0.001    5714.0  5514.0    5717.0    3156.0    1.0
sigma    2.073  0.104   1.895    2.287      0.001    0.001    5101.0  5024.0    5168.0    3039.0    1.0

Desired output

(plots unchanged)

          mean     sd  hpd_3%  hpd_97%  mcse_mean  mcse_sd  ess_mean  ess_sd  ess_bulk  ess_tail  r_hat
alpha    1.026  0.144   0.752    1.301      0.002    0.001    5294.0  5294.0    5258.0    2582.0    1.0
beta[A] -1.504  0.148  -1.800   -1.239      0.002    0.001    5550.0  5293.0    5570.0    3159.0    1.0
beta[B] -0.531  0.153  -0.825   -0.252      0.002    0.002    5597.0  5040.0    5582.0    3062.0    1.0
beta[C]  0.396  0.140   0.143    0.671      0.002    0.001    5700.0  5498.0    5682.0    3168.0    1.0
beta[D]  1.086  0.155   0.801    1.390      0.002    0.001    5714.0  5514.0    5717.0    3156.0    1.0
sigma    2.073  0.104   1.895    2.287      0.001    0.001    5101.0  5024.0    5168.0    3039.0    1.0
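The desired labels can be derived mechanically from the dims and coords already passed to from_pystan. A minimal stdlib sketch, assuming a hypothetical helper coord_labels (not an ArviZ function) that takes a variable name, its dimension names, and the coordinate values:

```python
from itertools import product


def coord_labels(var_name, dims, coords):
    """Build summary-style row labels such as "beta[A]" from coordinate values.

    Hypothetical helper for illustration only; not part of ArviZ.
    Scalar variables (no dims) keep their bare name.
    """
    if not dims:
        return [var_name]
    # One list of coordinate values per dimension, in dimension order.
    axes = [coords[dim] for dim in dims]
    # product() walks the axes in row-major order, i.e. the order in which
    # a flattened multidimensional array would be enumerated.
    return [
        f"{var_name}[{', '.join(str(value) for value in combo)}]"
        for combo in product(*axes)
    ]


coords = {"observation": list(range(200)), "covariate": ["A", "B", "C", "D"]}
print(coord_labels("beta", ["covariate"], coords))  # ['beta[A]', 'beta[B]', 'beta[C]', 'beta[D]']
print(coord_labels("alpha", [], coords))            # ['alpha']
```

For a two-dimensional variable such as X, the same helper would produce labels like "X[0, A]".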

Alternatively, a pandas MultiIndex could be used:

            mean     sd  hpd_3%  hpd_97%  mcse_mean  mcse_sd  ess_mean  ess_sd  ess_bulk  ess_tail  r_hat
alpha      1.026  0.144   0.752    1.301      0.002    0.001    5294.0  5294.0    5258.0    2582.0    1.0
beta    A -1.504  0.148  -1.800   -1.239      0.002    0.001    5550.0  5293.0    5570.0    3159.0    1.0
        B -0.531  0.153  -0.825   -0.252      0.002    0.002    5597.0  5040.0    5582.0    3062.0    1.0
        C  0.396  0.140   0.143    0.671      0.002    0.001    5700.0  5498.0    5682.0    3168.0    1.0
        D  1.086  0.155   0.801    1.390      0.002    0.001    5714.0  5514.0    5717.0    3156.0    1.0
sigma      2.073  0.104   1.895    2.287      0.001    0.001    5101.0  5024.0    5168.0    3039.0    1.0
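With a MultiIndex, each row would be keyed by a (variable, coordinate) tuple instead of a single string. A sketch of building those keys with the standard library, assuming a hypothetical multiindex_keys helper and an empty string as the second level for scalars; a list of tuples like this is the kind of input pandas.MultiIndex.from_tuples accepts:

```python
def multiindex_keys(var_dims, coords):
    """Return (variable, coordinate) row keys for a two-level index.

    Hypothetical helper for illustration. var_dims maps each variable name
    to its list of dims; only 0-D and 1-D variables are handled here.
    """
    keys = []
    for var_name, dims in var_dims.items():
        if not dims:
            # Scalars get an empty second level, as in the table above.
            keys.append((var_name, ""))
        elif len(dims) == 1:
            keys.extend((var_name, value) for value in coords[dims[0]])
        else:
            raise ValueError(f"{var_name}: only 0-D or 1-D variables supported")
    return keys


var_dims = {"alpha": [], "beta": ["covariate"], "sigma": []}
coords = {"covariate": ["A", "B", "C", "D"]}
keys = multiindex_keys(var_dims, coords)
print(keys)
```

Variables with two or more dims would need additional index levels (padded for scalars and 1-D variables), which is one plausible reason a MultiIndex is harder to work with.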
@ahartikainen
Contributor

Would an extra column work? A MultiIndex is harder to handle than a normal index.
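An extra column keeps a plain single-level index and moves the coordinate value into its own field. A minimal sketch of that row layout, assuming hypothetical column names variable and coord, with a list of dicts standing in for the DataFrame and only the mean column from the output above carried along:

```python
def rows_with_coord_column(var_dims, coords, means):
    """Pair each summary row with a separate coordinate column.

    Hypothetical layout for illustration: "variable" and "coord" are
    made-up column names, and means holds one value per row, in order.
    Only 0-D and 1-D variables are handled.
    """
    rows = []
    values = iter(means)
    for var_name, dims in var_dims.items():
        if len(dims) > 1:
            raise ValueError(f"{var_name}: only 0-D or 1-D variables supported")
        # Scalars get a single row with no coordinate value.
        coord_values = coords[dims[0]] if dims else [None]
        for value in coord_values:
            rows.append({"variable": var_name, "coord": value, "mean": next(values)})
    return rows


var_dims = {"alpha": [], "beta": ["covariate"], "sigma": []}
coords = {"covariate": ["A", "B", "C", "D"]}
# Posterior means copied from the summary output above.
means = [1.026, -1.504, -0.531, 0.396, 1.086, 2.073]
rows = rows_with_coord_column(var_dims, coords, means)
for row in rows:
    print(row)
```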

@ianmtaylor1
Author

I think an extra column would be perfectly fine.

If changing the structure of the data frame would cause unwanted side effects, then I think the simplest approach would be the first desired output: replacing "0", "1", etc. with the coordinate names in the text of the index. No extra column or MultiIndex would be needed in that case.
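That option leaves the data frame structure alone and only rewrites the label text. A sketch of it as a post-processing step on existing labels, assuming the name[i] format seen in the 0.6.1 output and a hypothetical relabel helper; it handles one-dimensional variables only:

```python
import re

# Matches one-dimensional positional labels such as "beta[0]".
POSITIONAL = re.compile(r"^(?P<name>\w+)\[(?P<index>\d+)\]$")


def relabel(labels, var_dims, coords):
    """Replace integer positions with coordinate names in summary labels.

    Hypothetical helper for illustration. Labels without a known 1-D
    dimension (e.g. "alpha" or an undeclared variable) are returned unchanged.
    """
    renamed = []
    for label in labels:
        match = POSITIONAL.match(label)
        dims = var_dims.get(match["name"], []) if match else []
        if match and len(dims) == 1:
            value = coords[dims[0]][int(match["index"])]
            renamed.append(f"{match['name']}[{value}]")
        else:
            renamed.append(label)
    return renamed


labels = ["alpha", "beta[0]", "beta[1]", "beta[2]", "beta[3]", "sigma"]
var_dims = {"beta": ["covariate"]}
coords = {"covariate": ["A", "B", "C", "D"]}
print(relabel(labels, var_dims, coords))
```

With a pandas summary frame, the returned list could be assigned back to the frame's index.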
