-
Notifications
You must be signed in to change notification settings - Fork 0
/
run-example.py
executable file
·111 lines (99 loc) · 3.1 KB
/
run-example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#!/usr/bin/env python3
import numpy as np
import polars as pl
import plotnine as pn
import pypfilt
import scipy.stats
# Load every scenario defined in the TOML settings file, keyed by its
# scenario identifier so individual scenarios can be looked up by name.
scenario_file = 'example.toml'
instances = {
    inst.scenario_id: inst
    for inst in pypfilt.load_instances(scenario_file)
}
# Run the 'simulate' scenario to generate synthetic observations, and save
# the simulated die rolls to a space-separated values file.
observations = pypfilt.simulate_from_model(instances['simulate'])
rolls = observations['roll']
rolls_file = 'simulated-rolls.ssv'
pypfilt.io.write_table(rolls_file, rolls, time_scale=pypfilt.Scalar())
# Draw prior samples for the six outcome probabilities from a flat
# Dirichlet distribution — the conjugate prior for a categorical
# distribution.  See the "Bayesian inference using conjugate prior"
# section of https://en.wikipedia.org/wiki/Categorical_distribution.
prior_file = 'prior.ssv'
num_particles = instances['example'].settings['filter']['particles']
rng = np.random.default_rng(seed=12345)
prior_dist = scipy.stats.dirichlet(alpha=[1] * 6)
prior_samples = prior_dist.rvs(size=num_particles, random_state=rng)
np.savetxt(
    prior_file, prior_samples, header='p_1 p_2 p_3 p_4 p_5 p_6', comments=''
)
# Fit the 'example' scenario to the simulated observations.
results = pypfilt.fit(instances['example'].build_context(), filename=None)
# Collect the credible intervals for each outcome, over time.
# The 'ci' column negates the interval probability — presumably so that
# ribbons are grouped/ordered from widest to narrowest; TODO confirm.
cints = results.estimation.tables['model_cints']
df_cints = (
    pl.from_numpy(cints)
    .select(['time', 'name', 'prob', 'ymin', 'ymax'])
    .with_columns(
        name=pl.col('name').apply(str),
        ci=pl.lit(0) - pl.col('prob'),
    )
)
# The ground-truth probabilities from which the observations were
# simulated: a biased die that rolls a six half of the time.
df_true = pl.DataFrame(
    {
        'name': [f'p_{i}' for i in range(1, 7)],
        'truth': 5 * [0.1] + [0.5],
    }
)
# Calculate the observed frequency of each outcome in the simulated rolls.
df_rolls = pl.from_numpy(rolls)
num_rolls = df_rolls.height
df_fracs = (
    df_rolls
    .with_columns(name=pl.col('value').apply(lambda roll: f'p_{roll}'))
    .groupby('name')
    .agg(frac=pl.col('time').count().truediv(pl.lit(num_rolls)))
)
# Custom breaks and labels for the credible-interval legend, ordered from
# the widest interval to the narrowest.
breaks = sorted(df_cints['prob'].unique().to_list(), reverse=True)
labels = [f'{level}%' for level in breaks]
# Plot the credible intervals against the ground truth and the observed
# outcome frequencies, one panel per outcome.
plot = (
    pn.ggplot()
    # Credible-interval ribbons, one per interval level.
    + pn.geom_ribbon(
        df_cints,
        pn.aes(
            'time',
            ymin='ymin',
            ymax='ymax',
            fill='prob',
            colour='prob',
            group='ci',
        ),
    )
    # Solid line: the true probability of each outcome.
    + pn.geom_hline(df_true, pn.aes(yintercept='truth'))
    # Dashed line: the empirical frequency of each outcome.
    + pn.geom_hline(df_fracs, pn.aes(yintercept='frac'), linetype='dashed')
    + pn.facet_wrap('name', labeller=lambda s: f'Pr(Roll a {s[-1]})')
    + pn.xlab('Number of observations')
    + pn.ylab('Probability')
    # Shared legend for both the fill and colour scales.
    + pn.scale_fill_continuous(
        name='CrI',
        breaks=breaks,
        labels=labels,
    )
    + pn.scale_colour_continuous(
        name='CrI',
        breaks=breaks,
        labels=labels,
    )
    + pn.guides(color=pn.guide_legend(), fill=pn.guide_legend())
)
# Render the figure to disk at print quality.
# NOTE(review): metadata={'Software': None} presumably suppresses the
# default PNG 'Software' tag — confirm against plotnine's save() docs.
plot_file = 'example.png'
print(f'Saving {plot_file} ...')
save_args = {
    'width': 8,
    'height': 6,
    'units': 'in',
    'dpi': 300,
    'verbose': False,
    'metadata': {'Software': None},
}
plot.save(plot_file, **save_args)