-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathwaterfall_chart.py
163 lines (124 loc) · 5.48 KB
/
waterfall_chart.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
'''
A function that attempts to generate a standard waterfall chart in generic Python. Requires two sequences,
one of labels and one of values, ordered accordingly.
'''
from matplotlib.ticker import FuncFormatter
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.lines as lines
#------------------------------------------
def _prepare_changes(index, data, sorted_value, threshold, other_label):
    '''
    Convert labels/values to arrays and apply optional sorting and grouping.

    Returns a ``(labels, values)`` pair of NumPy arrays of equal length.
    Raises ValueError if ``index`` and ``data`` have different lengths.
    '''
    labels = np.array(index)
    values = np.array(data)
    if len(labels) != len(values):
        raise ValueError(
            "index and data must have the same length (got {} and {})".format(
                len(labels), len(values)))

    # Order the changes by descending absolute value.
    if sorted_value:
        order = np.argsort(np.abs(values))[::-1]
        labels = labels[order]
        values = values[order]

    # Fold every change whose magnitude is below threshold * max(|change|)
    # into a single bar appended after the kept changes.
    if threshold and len(values):
        magnitudes = np.abs(values)
        cutoff = magnitudes.max() * threshold
        if cutoff > magnitudes.min():
            keep = magnitudes >= cutoff
            labels = np.append(labels[keep], other_label)
            values = np.append(values[keep], values[~keep].sum())

    return labels, values


def plot(index, data, Title="", x_lab="", y_lab="",
         formatting="{:,.1f}", green_color='#29EA38', red_color='#FB3C62', blue_color='#24CAFF',
         sorted_value=False, threshold=None, other_label='other', net_label='net',
         rotation_value=30, blank_color=(0, 0, 0, 0), figsize=(10, 10)):
    '''
    Draw a waterfall chart of sequential changes followed by a net-total bar.

    Parameters
    ----------
    index : sequence
        Bar labels, one per entry of ``data`` and in the same order.
    data : sequence of numbers
        The incremental changes, in display order.
    Title, x_lab, y_lab : str
        Chart title and x/y axis labels.
    formatting : str
        ``str.format`` template applied to y-axis tick values and bar labels.
    green_color, red_color : color
        Bar colour for positive / non-positive changes; also the label colour
        for positive / non-positive amounts (including the net bar's label).
    blue_color : color
        Colour of the final net bar.
    sorted_value : bool
        If True, order the changes by descending absolute value.
    threshold : float or None
        If truthy, changes whose magnitude is below ``threshold`` times the
        largest magnitude are summed into one bar labelled ``other_label``.
    other_label, net_label : str
        Labels for the grouped-small-changes bar and the net bar.
    rotation_value : float
        Rotation of the x tick labels, in degrees.
    blank_color : color
        Colour of the spacer bars under each change (transparent by default).
    figsize : tuple
        Figure size passed to ``plt.subplots``.

    Returns
    -------
    The ``matplotlib.pyplot`` module, with the chart drawn on a new figure.

    Raises
    ------
    ValueError
        If ``index`` and ``data`` have different lengths.
    '''
    labels, values = _prepare_changes(index, data, sorted_value, threshold, other_label)

    # Everything below is positional (not label-based), so duplicate labels,
    # integer labels, or a label equal to net_label cannot corrupt the chart.
    # levels[i] is the running total before change i; levels[-1] is the end total.
    levels = np.concatenate(([0], np.cumsum(values)))
    total = values.sum()

    # One entry per bar: every change, then the net bar.
    amounts = np.append(values, total)
    # Change bars float on the running total; the net bar starts at zero.
    bottoms = np.append(levels[:-1], 0)
    # Level at which each bar ends; labels are anchored relative to it.
    ends = np.append(levels[1:], total)
    # The net bar is always blue, whatever its value.
    bar_colors = [green_color if v > 0 else red_color for v in values] + [blue_color]
    tick_labels = list(labels) + [net_label]
    positions = np.arange(len(amounts))

    fig, ax = plt.subplots(figsize=figsize)
    ax.yaxis.set_major_formatter(
        FuncFormatter(lambda value, _pos: formatting.format(value)))

    # Spacer bars first, then the visible bars stacked on top of them.
    ax.bar(positions, bottoms, width=0.5, color=blank_color)
    ax.bar(positions, amounts, width=0.6, bottom=bottoms, color=bar_colors)

    ax.set_xlabel("\n" + x_lab)
    ax.set_ylabel(y_lab + "\n")

    # Keep zero in view so the chart does not zoom in on the changes alone.
    plot_max = ends.max()
    plot_min = ends.min()
    if (ends >= 0).all():
        plot_min = 0
    if (ends < 0).all():
        plot_max = 0
    maxmax = max(abs(plot_max), abs(plot_min))
    pos_offset = maxmax / 40
    plot_offset = maxmax / 15

    # Labels sit above positive bars and below non-positive ones.
    for position, (amount, end) in enumerate(zip(amounts, ends)):
        if amount > 0:
            y, color = end + pos_offset * 2, green_color
        else:
            y, color = end - pos_offset * 4, red_color
        ax.annotate(formatting.format(amount), (position, y),
                    ha="center", color=color, fontsize=9)

    # Pad the y range so the labels fit; a zero pad (all-zero data) would
    # request identical limits, so leave autoscaling alone in that case.
    pad = round(3.6 * plot_offset, 7)
    if pad > 0:
        ax.set_ylim(plot_min - pad, plot_max + pad)

    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels, rotation=rotation_value)

    # Dashed zero line and title.
    ax.axhline(0, color='black', linewidth=0.6, linestyle="dashed")
    ax.set_title(Title)
    fig.tight_layout()

    return plt