-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_util.py
259 lines (191 loc) · 8.19 KB
/
plot_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
# Default colour cycle: the twelve palette entries below are written as 0-255
# RGB triples and converted to matplotlib's 0-1 float scale.
ETHREE_COL = [
    tuple(channel / 255 for channel in rgb)
    for rgb in (
        (3, 78, 110), (175, 126, 0),
        (175, 34, 0), (0, 126, 51),
        (175, 93, 0), (10, 25, 120),
        (52, 157, 202), (255, 199, 57),
        (255, 95, 57), (48, 215, 15),
        (255, 162, 57), (68, 88, 210),
    )
]

# Tiny positive value that nonzero_map substitutes for exact zeros.
PRECISION = 1e-10
def nanmap(val, fillval=0):
    """Return ``fillval`` if ``val`` is a missing scalar, otherwise ``val`` unchanged.

    Intended for element-wise use, e.g. ``series.map(nanmap)``.

    ``pd.isna`` is used rather than ``np.isnan`` so that ``None``, ``pd.NA`` and
    ``NaT`` also count as missing, and non-numeric values (e.g. strings) pass
    through instead of raising ``TypeError``.  ``np.isnan(pd.NA)`` returns
    ``pd.NA``, whose truth value is ambiguous and would raise.

    Args:
        val: scalar value to test.
        fillval: replacement used when ``val`` is missing (default 0).

    Returns:
        ``fillval`` for missing values, otherwise ``val``.
    """
    return fillval if pd.isna(val) else val
def reverse_nanmap(val, tofill=0.):
    """Map the sentinel ``tofill`` back to NaN (roughly the inverse of ``nanmap``).

    Args:
        val: scalar value to test.
        tofill: value that should be turned into ``np.nan`` (default 0.0).

    Returns:
        ``np.nan`` when ``val == tofill``, otherwise ``val`` unchanged.
    """
    return np.nan if val == tofill else val
def nonzero_map(val):
    """Replace an exact zero with the tiny positive module constant ``PRECISION``.

    Any value that does not compare equal to zero (including NaN) is returned
    unchanged.
    """
    return val if val != 0. else PRECISION
def othermap(val, keys, other_key='Other'):
    """Lump values outside ``keys`` into a single catch-all label.

    Args:
        val: the label to classify.
        keys: container of labels that are kept as-is (must support ``in``).
        other_key: label returned for anything not in ``keys``.

    Returns:
        ``val`` if it is in ``keys``, otherwise ``other_key``.
    """
    if val not in keys:
        return other_key
    return val
def stacked_area(data, case, varname, output_directory, index_name, fmt='pdf', keys=None, labels_dict=None,
                 color_dict=None, scaling=1, yrange=None, ylabel='', case_index='Active_Cases',
                 time_index='Output_Year', value_name='Value', aggfunc=np.sum, map=None, fontsize=12,
                 xlim=(2015, 2050), xlabel='', select=None, other_key=None, legend=True, title='',
                 filename='', close_figure=False):
    """Draw and save a stacked area chart of one variable over time for one case.

    Rows of ``data`` whose ``case_index`` column equals ``case`` are pivoted into a
    table with ``time_index`` values as rows and ``index_name`` categories as
    columns.  The table is transformed element-wise by ``map`` and plotted as a
    stacked area chart, and the figure is saved under ``output_directory``.

    Args:
        data: long-format DataFrame containing the ``case_index``, ``time_index``,
            ``index_name`` and ``value_name`` columns (plus any ``select`` columns).
            It is not modified.
        case: value of ``case_index`` to plot.
        varname: variable name, used in the progress message and output filename.
        output_directory: directory the figure file is written to.
        index_name: column whose values become the stacked series.
        fmt: file format passed to ``plt.savefig`` and used as the file extension.
        keys: optional ordered list of series to plot.  Keys that are absent from
            the pivot, or whose values are all zero, are dropped.
        labels_dict: optional mapping applied to ``index_name`` values before
            pivoting.  Values missing from the mapping become NaN
            (``Series.map`` semantics).
        color_dict: optional mapping from series label to colour.  When None,
            ``ETHREE_COL`` is used.
        scaling: multiplier used by the default ``map`` (``abs(x) * scaling``).
        yrange, xlim: optional axis limits.
        ylabel, xlabel, title: axis and title text.  An empty ``title`` falls
            back to ``case``.
        case_index, time_index, value_name: column names in ``data``.
        aggfunc: aggregation used when pivoting.
        map: element-wise function applied to the pivoted table.
        fontsize: base font size.
        select: optional ``{column: list_of_values}``.  Only rows whose ``column``
            value is in the list are kept; the result is then summed per
            ``time_index``.  A scalar value is treated as a one-element list.
            Values that do not occur in the data are ignored.
        other_key: label for a catch-all category.  Categories not in ``keys``
            are renamed to it, so ``keys`` must be given as well.
        legend: whether to draw the legend (in reversed order).
        filename: optional prefix for the output filename.
        close_figure: when True, close the figure after saving.  This frees
            memory when many charts are produced in a loop.

    Returns:
        The pivoted, mapped DataFrame that was plotted.
    """
    print('Stacked area chart for variable: ' + varname + ', case: ' + case)
    # Explicit copy so the column assignments below act on an independent frame
    # (avoids pandas' SettingWithCopyWarning on a boolean-filtered slice).
    var = data[data[case_index] == case].copy()
    if labels_dict is not None:
        var[index_name] = var[index_name].map(labels_dict)
    if other_key is not None:
        var[index_name] = var[index_name].map(lambda x: othermap(x, keys, other_key))
    var[value_name] = var[value_name].map(nanmap)
    print(var.head())

    if select is None:
        pivot = pd.crosstab(var[time_index], var[index_name], values=var[value_name], aggfunc=aggfunc)
    else:
        index = list(select.keys()) + [time_index]
        pivot = pd.pivot_table(var, index=index, columns=index_name, values=value_name,
                               aggfunc=aggfunc, fill_value=0.)
        # Filter every selection column on its *own* index level.  A positional
        # .loc slice only ever matches the first level, which is wrong as soon
        # as ``select`` has more than one entry.
        mask = np.ones(len(pivot), dtype=bool)
        for column, wanted in select.items():
            if not pd.api.types.is_list_like(wanted):
                wanted = [wanted]
            mask &= pivot.index.get_level_values(column).isin(wanted)
        pivot = pivot[mask].groupby(level=time_index).sum()
    print(pivot.columns.tolist())

    if keys is None:
        active_list = pivot.columns.tolist()
    else:
        present = set(pivot.columns)
        # Keep the caller's order and drop absent or identically-zero series.
        # Series.abs().max() skips NaN (crosstab leaves NaN in empty cells),
        # whereas the builtin max() gave an answer that depended on where NaNs fell.
        active_list = [key for key in keys
                       if key in present and pivot[key].abs().max() != 0.]
    pivot = pivot[active_list]

    elementwise = map if map is not None else (lambda x: abs(x) * scaling)
    # DataFrame.applymap was renamed DataFrame.map in pandas 2.1 and the old
    # name is deprecated; use whichever this pandas version provides.
    if hasattr(pd.DataFrame, 'map'):
        pivot = pivot.map(elementwise)
    else:
        pivot = pivot.applymap(elementwise)

    fig, ax = plt.subplots()
    if color_dict is None:
        color = ETHREE_COL
    else:
        color = [color_dict[x] for x in pivot.columns.values]
    # lw=0: draw the filled areas without outline strokes.
    pivot.plot.area(ax=ax, fontsize=fontsize, color=color, lw=0, legend=legend)
    if yrange is not None:
        ax.set_ylim(yrange)
    ax.set_ylabel(ylabel, fontsize=fontsize)
    if xlim is not None:
        ax.set_xlim(xlim)
    ax.set_xlabel(xlabel)

    if legend:
        # Reverse the entries so the top-most area is listed first; the
        # bbox_to_anchor places the legend at the axes' top-right corner.
        handles, labels = ax.get_legend_handles_labels()
        lgd = ax.legend(handles=handles[::-1], labels=labels[::-1], bbox_to_anchor=[1, 1],
                        fontsize=fontsize)
        artists = (lgd,)
    else:
        artists = ()
    ax.set_title(title if title != '' else case, fontsize=fontsize + 2)

    prefix = filename + '_' if filename != '' else ''
    out_name = prefix + varname + '_' + case + '.' + fmt
    # bbox_extra_artists keeps the outside-anchored legend inside the saved image.
    plt.savefig(os.path.join(output_directory, out_name),
                dpi=600, transparent=False, bbox_extra_artists=artists,
                bbox_inches='tight', format=fmt)
    if close_figure:
        plt.close(fig)
    return pivot
def reshape(df, xkeys, ykeys):
    """Restrict ``df`` to the requested rows and columns, skipping missing labels.

    Each requested label that is not found is reported with a printed
    "Key not found" message and dropped, rather than raising.

    Args:
        df: DataFrame to subset.
        xkeys: ordered row (index) labels to keep, or None to keep every row.
        ykeys: ordered column labels to keep, or None to keep every column.

    Returns:
        ``(target, active_list)``: the subset DataFrame, and the list of row
        labels actually kept, in the requested order.
    """
    target = df
    if ykeys is not None:
        existing_cols = df.columns.tolist()  # hoisted: previously rebuilt every iteration
        active_list_y = []
        for key in ykeys:
            if key in existing_cols:
                active_list_y.append(key)
            else:
                # str() so non-string labels (e.g. integer years) don't crash the message.
                print('Key not found: ' + str(key) + ': removing.')
        target = target[active_list_y]
    if xkeys is None:
        active_list = df.index.tolist()
    else:
        existing_rows = df.index.tolist()
        active_list = []
        for key in xkeys:
            if key in existing_rows:
                active_list.append(key)
            else:
                print('Key not found: ' + str(key) + ': removing.')
    # .loc with a label list returns a new frame, so callers can modify
    # ``target`` without touching ``df``.
    target = target.loc[active_list]
    return target, active_list
def safe_dict(x, dictionary):
    """Look ``x`` up in ``dictionary``, falling back to ``x`` itself on a miss.

    Subscription (not ``dict.get``) is used, so mappings with custom
    missing-key behaviour such as ``collections.defaultdict`` behave exactly
    as ``dictionary[x]`` would.  Only ``KeyError`` triggers the fallback.
    """
    try:
        mapped = dictionary[x]
    except KeyError:
        return x
    return mapped
def stacked_bar(data, select, varname, output_directory, index_name, fmt='pdf', xkeys=None, ykeys=None,
                labels_dict=None, color_dict=None, scaling=1, yrange=None, ylabel='', case_index='Active_Cases',
                value_name='Value', aggfunc=np.sum, map=None, fontsize=12, time_index='Output_Year',
                xlabel='', title='', xlabels=None, base_case=None, other_key=None, filename=None,
                close_figure=False):
    """Draw and save a stacked bar chart comparing cases at one ``time_index`` value.

    ``data`` is pivoted to rows ``(time_index, case_index)`` and columns
    ``index_name``.  The rows for ``select`` are then taken, so each bar is one
    case and each stack segment is one ``index_name`` category.

    Args:
        data: long-format DataFrame.  It is not modified.
        select: the ``time_index`` value to plot (e.g. a year).
        varname: variable name, used in the progress message.
        output_directory: directory the figure file is written to.
        index_name: column whose values become the stacked segments.
        fmt: file format passed to ``plt.savefig`` and used as the file extension.
        xkeys: optional ordered list of cases (bars) to show.  Missing ones are
            skipped.
        ykeys: optional ordered list of segments to show.  Missing ones are
            skipped.  Also used as the keep-list when ``other_key`` is given.
        labels_dict: optional relabelling of ``index_name`` values.  Unmapped
            values are kept as-is.
        color_dict: optional mapping from segment label to colour.  When None,
            ``ETHREE_COL`` is used.
        scaling: multiplier used by the default ``map`` (``abs(x) * scaling``).
        yrange: optional y-axis limits.
        ylabel, xlabel, title: axis and title text.
        case_index, value_name, time_index: column names in ``data``.
        aggfunc: aggregation used when pivoting.
        map: element-wise function applied to the table before plotting.
        fontsize: base font size.
        xlabels: optional tick labels for the bars.  If any is longer than
            25 characters, the tick font is reduced by 3.
        base_case: if given, this case's row is subtracted from every shown
            case before ``map`` is applied.
        other_key: catch-all label for segments not in ``ykeys``.
        filename: output name without extension.  Defaults to
            ``title + '_' + str(select)``.
        close_figure: when True, close the figure after saving to free memory.
    """
    print('Stacked bar chart for variable: ' + varname + ', year: ' + str(select))
    var = data.copy()
    # Replace missing values with 0 before aggregating.  The result must be
    # assigned: Series methods return a new Series rather than mutating.
    var[value_name] = var[value_name].fillna(0)
    if labels_dict is not None:
        var[index_name] = var[index_name].map(lambda x: safe_dict(x, labels_dict))
    if other_key is not None:
        var[index_name] = var[index_name].map(lambda x: othermap(x, ykeys, other_key))
    pivot = var.pivot_table(index=[time_index, case_index], columns=index_name, values=value_name,
                            aggfunc=aggfunc)
    print(pivot.columns.tolist())
    print(list(pivot.index.levels[1]))
    pivot = pivot.loc[select]
    target, active_list = reshape(pivot, xkeys, ykeys)

    if base_case is not None:
        # Snapshot the baseline before the loop, so it is unaffected even if
        # ``target`` shares data with ``pivot`` and ``base_case`` is itself one
        # of the rows being rewritten.
        baseline = pivot.loc[base_case].copy()
        for case in active_list:
            target.loc[case] = target.loc[case].subtract(baseline, fill_value=0.)

    elementwise = map if map is not None else (lambda x: abs(x) * scaling)
    # DataFrame.applymap was renamed DataFrame.map in pandas 2.1 and the old
    # name is deprecated; use whichever this pandas version provides.
    if hasattr(pd.DataFrame, 'map'):
        target = target.map(elementwise)
    else:
        target = target.applymap(elementwise)

    fig, ax = plt.subplots()
    if color_dict is None:
        color = ETHREE_COL
    else:
        color = [color_dict[x] for x in target.columns.values]
    target.plot.bar(ax=ax, fontsize=fontsize, color=color, stacked=True)
    if yrange is not None:
        ax.set_ylim(yrange)
    ax.set_ylabel(ylabel, fontsize=fontsize)
    ax.set_xlabel(xlabel, fontsize=fontsize)
    if xlabels is not None:
        # default=0 so an empty label list does not make max() raise.
        longest = max((len(x) for x in xlabels), default=0)
        if longest > 25 and fontsize is not None:
            xlfontsize = fontsize - 3
        else:
            xlfontsize = fontsize
        ax.set_xticklabels(xlabels, rotation=0, fontsize=xlfontsize)

    # Reverse the legend so the top segment of each bar is listed first; the
    # bbox_to_anchor places it at the axes' top-right corner.
    handles, labels = ax.get_legend_handles_labels()
    lgd = ax.legend(handles=handles[::-1], labels=labels[::-1], bbox_to_anchor=[1, 1], fontsize=fontsize)
    ax.set_title(title, fontsize=fontsize + 2)
    if filename is None:
        filename = title + '_' + str(select) + '.' + fmt
    else:
        filename += '.' + fmt
    # bbox_extra_artists keeps the outside-anchored legend inside the saved image.
    plt.savefig(os.path.join(output_directory, filename),
                dpi=600, transparent=False, bbox_extra_artists=(lgd,),
                bbox_inches='tight', format=fmt)
    if close_figure:
        plt.close(fig)
def shade(ax, group_size, num_bars):
    """Shade alternating groups of bars with a light-grey background band.

    Group boundaries are placed at half-integers starting from -0.5.  This
    presumably matches bars centred on integer x positions 0..num_bars-1, as
    in a pandas bar plot; confirm for other layouts.  Starting with the first
    group, every other group of ``group_size`` bars gets a ``gainsboro`` band
    drawn behind the data (zorder=0).

    Args:
        ax: matplotlib Axes to draw on.
        group_size: number of bars per group.
        num_bars: total number of bars.

    Returns:
        The same ``ax``, for chaining.
    """
    for group, left in enumerate(np.arange(-0.5, -0.5 + num_bars, group_size)):
        if group % 2 == 0:
            # axvspan takes y in axes coordinates, so each band covers the full
            # plot height.  (Sampling y in unit steps stopped short of the top
            # and drew nothing for ranges < 1.)  It also leaves the y data
            # limits untouched.
            ax.axvspan(left, left + group_size, zorder=0, facecolor='gainsboro', alpha=1)
    return ax