This repository has been archived by the owner on May 16, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
metrics.py
70 lines (55 loc) · 2.47 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""
A collection of functions for recording the data from experiments.
"""
import jax
import jax.numpy as jnp
def measurer(net):
    """Build the metric functions used to evaluate a model.

    Args:
        net: a model object exposing ``apply(params, X)`` that returns
            per-class scores for each sample in ``X``.

    Returns:
        dict mapping metric names ('acc', 'asr') to jitted functions.
    """
    @jax.jit
    def accuracy(params, X, y):
        """Fraction of samples whose argmax prediction equals the label."""
        logits = net.apply(params, X)
        predicted_labels = jnp.argmax(logits, axis=-1)
        return jnp.mean(predicted_labels == y)

    @jax.jit
    def attack_success_rate(params, X, y, attack_from, attack_to):
        """Among samples labelled ``attack_from``, the fraction predicted
        as ``attack_to`` (backdoor/label-flip success rate)."""
        predicted_labels = jnp.argmax(net.apply(params, X), axis=-1)
        targeted = y == attack_from
        # Mask out non-targeted samples with -1 so they never count as hits.
        hits = jnp.where(targeted, predicted_labels, -1) == attack_to
        # NOTE(review): yields nan when no sample has label attack_from — confirm intended.
        return jnp.sum(hits) / jnp.sum(targeted)

    return {'acc': accuracy, 'asr': attack_success_rate}
def create_recorder(evals, train=False, test=False, add_evals=None):
    """Build the dictionary that experiment results are appended into.

    Args:
        evals: iterable of metric names (e.g. ['acc', 'asr']).
        train: if True, track each metric under a "train {metric}" key.
        test: if True, track each metric under a "test {metric}" key.
        add_evals: optional iterable of extra keys tracked verbatim.

    Returns:
        dict mapping each tracked key to a fresh empty list.
    """
    tracked = {}
    splits = []
    if train:
        splits.append("train")
    if test:
        splits.append("test")
    for split in splits:
        for metric in evals:
            tracked[f"{split} {metric}"] = []
    if add_evals is not None:
        for extra in add_evals:
            tracked[extra] = []
    return tracked
def record(results, evaluator, params, train_ds=None, test_ds=None, add_recs=None, **kwargs):
    """Append one evaluation per tracked key to ``results``, in place.

    Keys are dispatched by substring: "train"/"test" selects the dataset
    iterator, "acc"/"asr" selects the metric function. Each matching key
    consumes one batch via ``next()`` from its dataset iterator.

    Args:
        results: dict of key -> list, as built by create_recorder.
        evaluator: dict of metric functions, as built by measurer.
        params: model parameters forwarded to the metric functions.
        train_ds: batch iterator used for keys containing "train".
        test_ds: batch iterator used for all other keys.
        add_recs: optional dict of extra values appended verbatim.
        **kwargs: must supply 'attack_from' and 'attack_to' when any
            tracked key triggers the 'asr' metric.
    """
    for key, history in results.items():
        batches = train_ds if "train" in key else test_ds
        if "acc" in key:
            history.append(evaluator['acc'](params, *next(batches)))
        if "asr" in key and ("test" in key or "train" in key):
            history.append(
                evaluator['asr'](params, *next(batches), kwargs['attack_from'], kwargs['attack_to'])
            )
    if add_recs is not None:
        for key, value in add_recs.items():
            results[key].append(value)
def finalize(results):
    """Convert every recorded list in ``results`` to a jnp array, in place.

    Returns the same dict for convenience.
    """
    for key in results:
        results[key] = jnp.array(results[key])
    return results
def tabulate(results, total_rounds, ri=10):
    """Render the recorded metrics as a human-readable multi-line table.

    Args:
        results: dict of key -> jnp array (as produced by finalize).
        total_rounds: total number of training rounds in the experiment.
        ri: recording interval — rounds between recorded entries.

    Returns:
        one line per metric with overall and second-half mean/std,
        with no trailing newline.
    """
    # Index of the entry recorded at the experiment's halfway point.
    halftime = int((total_rounds / 2) / ri)
    rows = []
    for name, values in results.items():
        late = values[halftime:]
        rows.append(
            f"[{name}] mean: {values.mean()}, std: {values.std()} "
            f"[after {halftime * ri} rounds] mean {late.mean()}, std: {late.std()}"
        )
    return "\n".join(rows)
def csvline(ds_name, alg, adv, results, total_rounds, ri=10):
    """Summarize one experiment as a single CSV row (newline-terminated).

    Columns: dataset, algorithm, adversary fraction (percent), final test
    accuracy, ASR mean/std overall, ASR mean/std after the halfway point.

    Args:
        ds_name: dataset name.
        alg: aggregation algorithm name.
        adv: fraction of adversaries, formatted as a percentage.
        results: dict of jnp arrays; must contain 'test accuracy' and
            'test asr'. NOTE(review): measurer registers the accuracy
            metric as 'acc', so callers must pass evals=['accuracy', ...]
            for this key to exist — confirm against callers.
        total_rounds: total number of training rounds.
        ri: recording interval — rounds between recorded entries.
    """
    halftime = int((total_rounds / 2) / ri)
    asr = results['test asr']
    late = asr[halftime:]
    return (
        f"{ds_name},{alg},{adv:.2%},"
        f"{results['test accuracy'][-1]},"
        f"{asr.mean()},{asr.std()},"
        f"{late.mean()},{late.std()}\n"
    )