-
Notifications
You must be signed in to change notification settings - Fork 4
/
drift_magnitude.py
147 lines (113 loc) · 4.83 KB
/
drift_magnitude.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
""" Functions to compute concept drift magnitude
- lga"""
import numpy as np
# from scipy.special import rel_entr
from scipy.stats import entropy
def total_variation(pk, qk, axis=0):
    """Return the total variation distance between `pk` and `qk`.

    The total variation distance (Levin 2008) serves as a measure of concept
    drift magnitude (Webb 2018). It is half the sum of the absolute
    element-wise differences, reduced along `axis`.

    Parameters
    ----------
    pk, qk : array_like
        Distributions to compare; both are converted with `np.asarray` and
        must have identical shapes.
    axis : int
        Axis along which the differences are summed.

    Raises
    ------
    ValueError
        If the shapes of `pk` and `qk` differ.
    """
    left, right = np.asarray(pk), np.asarray(qk)
    if left.shape != right.shape:
        raise ValueError("pk and qk be must have same shapes")

    # Half the L1 distance between the two arrays along `axis`
    abs_gap = np.abs(left - right)
    return 0.5 * np.sum(abs_gap, axis=axis)
def kl_divergence(pk, qk, axis=0):
    """Return the Kullback-Leibler divergence of `pk` relative to `qk`.

    The computation is delegated to `scipy.stats.entropy(pk, qk, axis=axis)`,
    which, when given a second distribution, returns the relative entropy
    along `axis`.

    Parameters
    ----------
    pk, qk : array_like
        Distributions to compare; both are converted with `np.asarray` and
        must have identical shapes.
    axis : int
        Axis along which the divergence is computed.

    Raises
    ------
    ValueError
        If the shapes of `pk` and `qk` differ.
    """
    left, right = np.asarray(pk), np.asarray(qk)
    if left.shape != right.shape:
        raise ValueError("pk and qk be must have same shapes")

    # With two arguments, scipy's `entropy` yields the relative entropy
    return entropy(left, right, axis=axis)
# Registry of built-in drift metrics, keyed by the name used to select them
VALID_METRICS = dict(
    total_variation=total_variation,
    kl_divergence=kl_divergence,
)
def drift_magnitude(pk, qk, metric='total_variation', axis=0, **kwargs):
    """
    Compute the drift magnitude between `pk` and `qk` according to `metric`.

    Both inputs are converted to arrays and scaled so that they sum to one
    along `axis`; the metric is then evaluated along that same axis.

    Parameters
    ----------
    pk, qk : array_like
        Unnormalized weights (e.g. counts) of the two distributions.
        There is no guard against a zero sum along `axis`.
    metric : str or callable
        Either a key of `VALID_METRICS` or a callable invoked as
        `metric(pk, qk, **kwargs)`. A callable whose signature declares an
        `axis` parameter additionally receives `axis=axis`.
    axis : int
        Axis used for normalization and forwarded to the metric.
    **kwargs
        Extra keyword arguments passed through to the metric function.

    Returns
    -------
    The value returned by the metric function.

    Raises
    ------
    KeyError
        If `metric` is not callable and is not a key of `VALID_METRICS`.
    """
    # Local import keeps this function self-contained
    import inspect

    # Cast to numpy ndarray
    pk = np.asarray(pk)
    qk = np.asarray(qk)

    # Normalize along `axis`; `1.0 *` forces floating-point division
    pk = 1.0 * pk / np.sum(pk, axis=axis, keepdims=True)
    qk = 1.0 * qk / np.sum(qk, axis=axis, keepdims=True)

    if callable(metric):
        # Forward `axis` only to callables that explicitly declare it, so
        # user-supplied metrics without an `axis` parameter keep working.
        try:
            takes_axis = 'axis' in inspect.signature(metric).parameters
        except (TypeError, ValueError):
            # Signature not introspectable: call as before
            takes_axis = False
        if takes_axis:
            return metric(pk, qk, axis=axis, **kwargs)
        return metric(pk, qk, **kwargs)

    # The metric must reduce along the axis that was used for normalization;
    # otherwise the metric's own default axis would be used instead.
    return VALID_METRICS[metric](pk, qk, axis=axis, **kwargs)
def drift_magnitude_per_time(time, labels, t_start=None, history=None, cumulative=False, verbose=False,
                             metric='total_variation'):
    """Compute the drift magnitude of the label distribution per time step.

    For every evaluated time step, the label counts of the data seen before
    that step are compared with the label counts of the "current" data via
    `drift_magnitude`.

    Parameters
    ----------
    time : array_like
        Time step of each sample; must have the same shape as `labels`.
    labels : array_like
        Label of each sample. Labels are mapped to contiguous indices
        internally, so they do not need to be the integers 0..K-1.
    t_start : optional
        First time step to evaluate. If None, the second-smallest distinct
        time step is used, which requires at least two distinct time steps.
    history : optional
        Width of the sliding window over previous data. If None, all data
        before the evaluated step is used. Must be positive if given.
    cumulative : bool
        If True, the current data covers everything up to and including the
        evaluated step (limited to the `history` window when one is given);
        otherwise only the samples of the evaluated step itself.
    verbose : bool
        If True, print summary information about labels and time steps.
    metric : str or callable
        Passed on to `drift_magnitude`.

    Returns
    -------
    tuple
        `(eval_steps, drift)`: the array of evaluated time steps and a list
        with one drift magnitude per evaluated step.
    """
    time = np.asarray(time)
    labels = np.asarray(labels)
    assert time.shape == labels.shape
    assert history is None or history > 0, "History for previous data is exclusive, min 1"

    # Flatten so that boolean masks and index arrays line up for any rank
    time = time.ravel()
    labels = labels.ravel()

    # Map every label to an index in 0..num_labels-1. Using raw label values
    # as array indices only works if the labels already are exactly that range.
    unique_labels, label_index = np.unique(labels, return_inverse=True)
    num_labels = len(unique_labels)
    if verbose:
        print("Found", num_labels, "globally unique label")

    steps = np.unique(time)
    if verbose:
        print("Found time steps between", steps[0], "and", steps[-1])

    if t_start is None:
        # If not given, start with step 1 (rather than 0)
        t_start = steps[1]

    eval_steps = steps[steps >= t_start]
    if verbose:
        print("Using time steps between", eval_steps[0], "and", eval_steps[-1])

    drift = []
    for eval_step in eval_steps:
        # Lower bound of the sliding window; None means infinite history
        in_window = None if history is None else time >= (eval_step - history)

        # Previous data: strictly before the evaluated step
        previous_mask = time < eval_step
        if in_window is not None:
            previous_mask = previous_mask & in_window

        # Current data
        if cumulative:
            # Everything up to and including this step; without a history
            # there is no lower bound to apply.
            current_mask = time <= eval_step
            if in_window is not None:
                current_mask = current_mask & in_window
        else:
            current_mask = time == eval_step

        # Count labels into arrays of equal length (shapes must match!)
        pk = np.bincount(label_index[previous_mask], minlength=num_labels)
        qk = np.bincount(label_index[current_mask], minlength=num_labels)

        # Compute drift magnitude
        dm = drift_magnitude(pk, qk, metric=metric)
        # Emitted for every step regardless of `verbose`
        print(eval_step, ',', dm, sep='')
        drift.append(dm)

    return eval_steps, drift
def main():
    """Command-line entry point.

    Loads the time and label arrays from the given numpy files, computes the
    drift magnitude per time step and draws it as a line plot, which is
    either saved to `--save_plot` or shown interactively.
    """
    import argparse
    import seaborn
    import matplotlib.pyplot as plt

    cli = argparse.ArgumentParser()
    cli.add_argument('time', help="Path to numpy file with times")
    cli.add_argument('labels', help="Path to numpy file with labels")
    cli.add_argument('--t_start', default=None, type=int, help="Start time")
    cli.add_argument('--history', default=None, type=int, help="Sliding window size for previous labels")
    cli.add_argument('--cumulative', default=False, action='store_true', help="Apply history also to 'right-side'")
    cli.add_argument('--save_plot', default=None, type=str, help="Path to save plot")
    cli.add_argument('--info', default=False, action='store_true', help="Print some more info")
    opts = cli.parse_args()

    time = np.load(opts.time)
    labels = np.load(opts.labels)

    # `--info` doubles as the verbosity switch for the computation below
    if opts.info:
        print("Time shape", time.shape, "dtype", time.dtype)
        print("Labels shape", labels.shape, "dtype", labels.dtype)

    eval_steps, magnitudes = drift_magnitude_per_time(
        time,
        labels,
        t_start=opts.t_start,
        verbose=opts.info,
        history=opts.history,
        cumulative=opts.cumulative,
    )

    seaborn.lineplot(x=eval_steps, y=magnitudes)

    # Without a target path the plot is displayed instead of written to disk
    if not opts.save_plot:
        plt.show()
        return
    plt.savefig(opts.save_plot)


if __name__ == "__main__":
    main()