-
Notifications
You must be signed in to change notification settings - Fork 0
/
distances.py
83 lines (62 loc) · 2.5 KB
/
distances.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from abc import ABC, abstractmethod
from typing import TypeVar
import eagerpy as ep
def flatten(x: ep.Tensor, keep: int = 1) -> ep.Tensor:
    """Flatten ``x`` starting at axis ``keep`` via eagerpy's ``Tensor.flatten``.

    Args:
        x: Tensor to flatten.
        keep: Index of the first axis that gets flattened; axes before it are
            left untouched. With the default of 1 the leading (batch) axis is
            preserved.

    Returns:
        The flattened tensor.
    """
    flattened = x.flatten(start=keep)
    return flattened
def atleast_kd(x: ep.Tensor, k: int) -> ep.Tensor:
    """Reshape ``x`` by appending trailing size-1 axes until it has ``k`` dims.

    If ``x`` already has ``k`` or more dimensions its shape is left unchanged
    (a non-positive repeat count of ``(1,)`` yields the empty tuple).

    Args:
        x: Tensor to reshape.
        k: Minimum number of dimensions of the result.

    Returns:
        ``x`` reshaped with ``max(k - x.ndim, 0)`` trailing singleton axes.
    """
    missing_axes = k - x.ndim
    target_shape = x.shape + (1,) * missing_axes
    return x.reshape(target_shape)
T = TypeVar("T")


class Distance(ABC):
    """Abstract interface for distance measures between batches of inputs.

    Concrete subclasses must implement both ``__call__`` (compute the
    distance) and ``clip_perturbation`` (project a perturbed input back
    within a given distance of its reference).
    """

    @abstractmethod
    def __call__(self, reference: T, perturbed: T) -> T:
        """Return the distance between ``reference`` and ``perturbed``."""

    @abstractmethod
    def clip_perturbation(self, references: T, perturbed: T, epsilon: float) -> T:
        """Return ``perturbed`` with its perturbation limited to ``epsilon``."""
class LpDistance(Distance):
    """Distance measure based on the Lp norm of the flattened perturbation.

    Args:
        p: Order of the norm, forwarded to ``ep.norms.lp``. ``ep.inf`` and
            ``0`` receive special handling in ``clip_perturbation``.
    """

    def __init__(self, p: float):
        self.p = p

    def __repr__(self) -> str:
        return f"LpDistance({self.p})"

    def __str__(self) -> str:
        return f"L{self.p} distance"

    def __call__(self, references: T, perturbed: T) -> T:
        """Calculates the distances from references to perturbed using the Lp norm.

        Args:
            references: A batch of reference inputs.
            perturbed: A batch of perturbed inputs.

        Returns:
            A 1D tensor with the distances from references to perturbed.
        """
        (x, y), restore_type = ep.astensors_(references, perturbed)
        # Flatten everything after the leading axis so one norm is computed per sample.
        norms = ep.norms.lp(flatten(y - x), self.p, axis=-1)
        return restore_type(norms)

    def clip_perturbation(self, references: T, perturbed: T, epsilon: float) -> T:
        """Clips the perturbations to epsilon and returns the new perturbed inputs.

        Args:
            references: A batch of reference inputs.
            perturbed: A batch of perturbed inputs.
            epsilon: Maximum allowed Lp norm of each perturbation.

        Returns:
            A tensor like perturbed but with the perturbation clipped to epsilon.

        Raises:
            NotImplementedError: If ``self.p == 0`` and some perturbation has
                an L0 norm larger than ``epsilon`` (reducing L0 norms is not
                supported).
        """
        (x, y), restore_type = ep.astensors_(references, perturbed)
        # The perturbation; deliberately not called ``p`` to avoid confusion
        # with the norm order ``self.p``.
        delta = y - x

        if self.p == ep.inf:
            # Linf ball: clip every element independently.
            clipped_delta = ep.clip(delta, -epsilon, epsilon)
            return restore_type(x + clipped_delta)

        norms = ep.norms.lp(flatten(delta), self.p, axis=-1)

        if self.p == 0:
            # L0 norms cannot be reduced by rescaling, so only accept inputs
            # that already satisfy the bound. Compare the raw norms: after the
            # 1e-12 clamp below, epsilon == 0 with an unperturbed sample would
            # give factor == 0 and be rejected by mistake.
            if (norms <= epsilon).all():
                return perturbed
            raise NotImplementedError("reducing L0 norms not yet supported")

        norms = ep.maximum(norms, 1e-12)  # avoid division by zero
        # Per-sample scale capped at 1: perturbations are shrunk, never enlarged.
        factor = ep.minimum(1, epsilon / norms)
        # Add trailing singleton axes so the factor broadcasts against delta.
        factor = atleast_kd(factor, x.ndim)
        return restore_type(x + factor * delta)
# Ready-made distance instances for the most common norm orders
# (created in the order L0, L1, L2, Linf).
l0, l1, l2, linf = (LpDistance(order) for order in (0, 1, 2, ep.inf))