-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathplot_parameters_figures.py
100 lines (73 loc) · 3.81 KB
/
plot_parameters_figures.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from matplotlib.dates import date2num, num2date
from matplotlib import dates as mdates
from matplotlib import ticker
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch
from matplotlib import pyplot as plt
from global_config import config
from functions.adjust_cases_functions import prepare_cases
from models.seird_model import SEIRD
import matplotlib.pyplot as plt
import scipy.io as sio
import pandas as pd
import numpy as np
import os
from global_config import config
from functions.samples_utils import create_df_response
import sys
# Command-line arguments:
#   argv[1] -> poly_id of the region to plot (int)
#   argv[2] -> name of the results sub-folder under 'weekly_forecast'
# Both are required, so at least three argv entries (script name included)
# must be present; the previous `< 2` check let a single argument through and
# then crashed on `sys.argv[2]`.
if len(sys.argv) < 3:
    raise SystemExit(f'usage: {sys.argv[0]} POLY_ID NAME_DIR')
poly_run = int(sys.argv[1])
name_dir = str(sys.argv[2])
# Directories are read from the project-wide configuration object.
data_dir = config.get_property('data_dir_covid')
geo_dir = config.get_property('geo_dir')
data_dir_mnps = config.get_property('data_dir_col')
results_dir = config.get_property('results_dir')

agglomerated_folder = os.path.join(data_dir, 'data_stages', 'colombia', 'agglomerated', 'geometry')

# Load the case table, keep only the requested polygon and index it by date.
# A list indexer (`.loc[[poly_run]]`) always returns a DataFrame; a scalar
# indexer returns a Series when the polygon has exactly one row, which would
# make the following `set_index('date_time')` fail.
data = pd.read_csv(os.path.join(agglomerated_folder, 'cases.csv'), parse_dates=['date_time'], dayfirst=True)
data = data.set_index('poly_id').loc[[poly_run]].set_index('date_time')

# Select the two count columns *before* aggregating so unrelated (possibly
# non-summable) columns are never summed, then regularise to a daily grid.
data = data[['num_cases', 'num_diseased']].resample('D').sum().fillna(0)

# Smooth both series with the project helper; it produces the
# 'smoothed_<col>' columns that are renamed below.
data = prepare_cases(data, col='num_cases', cutoff=0)
data = prepare_cases(data, col='num_diseased', cutoff=0)
data = data.rename(columns={'smoothed_num_cases': 'confirmed', 'smoothed_num_diseased': 'death'})[['confirmed', 'death']]

# Drop the most recent 14 days.
# NOTE(review): presumably to avoid incomplete, reporting-lagged data — confirm.
data = data.iloc[:-14]
if data.empty:
    raise ValueError(f'no case data left for poly_id {poly_run} after dropping the last 14 days')

model = SEIRD(
    confirmed=data['confirmed'].cumsum(),
    death=data['death'].cumsum(),
    T=len(data),
    N=8181047,  # NOTE(review): presumably the population of the selected polygon — confirm / parameterise
)
T_future = 27

# Results live under <results_dir>/weekly_forecast/<name_dir>/<last date kept>.
path_to_save = os.path.join(results_dir, 'weekly_forecast', name_dir, pd.to_datetime(data.index.values[-1]).strftime('%Y-%m-%d'))
def load_samples(filename):
    """Load MCMC, posterior-predictive and forecast samples from an ``.npz`` file.

    The archive is expected to hold three entries, ``mcmc_samples``,
    ``post_pred_samples`` and ``forecast_samples``, each a 0-d object array
    wrapping a Python object (unwrapped here with ``.item()``).

    Args:
        filename: Path to the ``.npz`` archive.

    Returns:
        Tuple ``(mcmc_samples, post_pred_samples, forecast_samples)``.

    Note:
        ``allow_pickle=True`` is required to read object arrays, but it
        executes pickle data: only load archives produced by a trusted source.
    """
    # Use the NpzFile as a context manager so the underlying file handle is
    # released; `.item()` materialises each object before the file is closed.
    with np.load(filename, allow_pickle=True) as archive:
        mcmc_samples = archive['mcmc_samples'].item()
        post_pred_samples = archive['post_pred_samples'].item()
        forecast_samples = archive['forecast_samples'].item()
    return mcmc_samples, post_pred_samples, forecast_samples
# Load the posterior samples saved for this run and summarise beta(t) into
# median / credible-interval columns with the project helper.
mcmc_samples, _, _ = load_samples(os.path.join(path_to_save, 'samples.npz'))
beta = mcmc_samples['beta']
beta_df = create_df_response(beta, beta.shape[-1], date_init='2020-03-06', quantiles=[50, 80, 95], forecast_horizon=int(7 * 4), use_future=False)

fig, ax = plt.subplots(1, 1, figsize=(15.5, 7))
beta_dates = beta_df.index.values
ax.plot(beta_dates, beta_df["median"], color='darkred', alpha=0.4, label='Median - Nowcast')
ax.fill_between(beta_dates, beta_df["low_95"], beta_df["high_95"], color='darkred', alpha=0.4, label='95 CI - Nowcast')
# The inner band is the 80% interval; it was previously mislabelled as 95.
# (Labels are attached to the artists; no legend is drawn by this script.)
ax.fill_between(beta_dates, beta_df["low_80"], beta_df["high_80"], color='darkred', alpha=0.4, label='80 CI - Nowcast')

# Month ticks with abbreviated month names, daily minor ticks.
ax.xaxis.set_major_locator(mdates.MonthLocator())
ax.xaxis.set_major_formatter(mdates.DateFormatter('%b'))
ax.xaxis.set_minor_locator(mdates.DayLocator())

# Frameless look: hide all four spines, keep faint horizontal grid lines.
for side in ('top', 'right', 'left', 'bottom'):
    ax.spines[side].set_visible(False)
ax.grid(which='major', axis='y', c='k', alpha=.1, zorder=-2)
ax.tick_params(axis='both', labelsize=15)
ax.yaxis.set_major_formatter(ticker.StrMethodFormatter("{x:.1f}"))
ax.set_ylabel(r'$\beta(t)$ - Contact Rate', size=15)

fig.savefig(os.path.join(path_to_save, 'contact_rate.png'), dpi=300, bbox_inches='tight', transparent=False)
plt.close(fig)