-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdistance_functions.py
95 lines (68 loc) · 2.88 KB
/
distance_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
import torch.nn as nn
from scipy.stats import wasserstein_distance
import numpy as np
import math
import lpips
# Module-level LPIPS model with a VGG backbone, shared by perceptual_dist.
# NOTE: it is built and moved to 'cuda' at import time, so importing this
# module requires the lpips package and an available CUDA device.
loss_fn_vgg = lpips.LPIPS(net='vgg').to('cuda') # closer to "traditional" perceptual loss, when used for optimization
def normalize_in_range(batch, range_max=1, range_min=-1):
    """Min-max normalize every sample of a batch into [range_min, range_max].

    Statistics are computed per sample (dim 0 is the batch dimension) over
    all remaining dimensions, so each sample's minimum maps to ``range_min``
    and its maximum to ``range_max``.

    Args:
        batch: tensor of rank >= 2 whose first dimension indexes samples.
        range_max: value the per-sample maximum is mapped to (default 1).
        range_min: value the per-sample minimum is mapped to (default -1).

    Returns:
        Tensor with the same shape as ``batch``. A constant sample (max equal
        to min) is mapped entirely to ``range_min`` instead of producing NaN.
    """
    # (B, 1, ..., 1) so per-sample statistics broadcast over every non-batch
    # dimension whatever the rank of the input is.
    stat_shape = (batch.shape[0],) + (1,) * (batch.dim() - 1)
    batch_flat = torch.flatten(batch, start_dim=1)
    batch_min = torch.min(batch_flat, dim=1)[0].reshape(stat_shape)
    batch_max = torch.max(batch_flat, dim=1)[0].reshape(stat_shape)
    span = batch_max - batch_min
    # A constant sample has zero span; dividing by 1 there keeps (batch - min)
    # at 0 rather than producing 0/0 = NaN.
    span = torch.where(span > 0, span, torch.ones_like(span))
    batch = (batch - batch_min) / span
    return batch * (range_max - range_min) + range_min
def perceptual_dist(batch1, batch2):
    """Return the LPIPS (VGG) perceptual distance between paired images.

    Each batch is first min-max normalized per sample to [-1, 1] with
    ``normalize_in_range``, then scored by the module-level ``loss_fn_vgg``.

    Args:
        batch1, batch2: image batches of matching shape. Presumably
            (batch, channels, height, width) on the 'cuda' device the model
            was moved to -- TODO confirm with callers.

    Returns:
        numpy array holding the squeezed LPIPS output, presumably one value
        per image pair. NOTE(review): ``torch.squeeze`` turns a batch of
        size 1 into a 0-d array.
    """
    # The result is converted to numpy, so no graph is needed; running under
    # no_grad also keeps Tensor.numpy() from raising on grad-tracking outputs.
    with torch.no_grad():
        batch1 = normalize_in_range(batch1)
        batch2 = normalize_in_range(batch2)
        dist = loss_fn_vgg(batch1, batch2)
    dist = torch.squeeze(dist)
    return dist.cpu().numpy()
def angular_distance(batch1, batch2):
    """Return the per-sample angular distance between two batches.

    Each sample is flattened over all non-batch dimensions. The cosine
    similarity (eps=1e-6) of each pair is turned into an angle with
    arccos and divided by pi. The result lies in [0, 1]: 0 means the same
    direction and 1 means opposite directions.

    Returns:
        numpy array of shape (batch_size,).
    """
    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
    b1_flat = torch.flatten(batch1, start_dim=1)
    b2_flat = torch.flatten(batch2, start_dim=1)
    # Floating-point rounding can push the similarity of (anti)parallel
    # vectors just outside [-1, 1], where acos returns NaN.
    sim = torch.clamp(cos(b1_flat, b2_flat), -1.0, 1.0)
    dist = torch.acos(sim) / math.pi
    # detach() so inputs that require grad do not make .numpy() raise.
    return dist.detach().cpu().numpy()
def total_variation(batch: torch.Tensor, p: int = 1):
    """Return the per-image total variation of a batch.

    Args:
        batch: tensor of images indexed as
            batch_size * color_channel * height * width.
        p: order of the norm taken across the color channels of each
            neighbouring-pixel difference.

    Returns:
        numpy array of shape (batch_size,). For every image it holds the sum,
        over all vertically and horizontally adjacent pixel pairs, of the
        channel-wise p-norm of their difference.
    """
    # Vertical and horizontal neighbour differences, each flattened to
    # (batch, channel, n_pairs) so they can be concatenated.
    diff_v = torch.flatten(batch[:, :, 1:, :] - batch[:, :, :-1, :], start_dim=2)
    diff_h = torch.flatten(batch[:, :, :, 1:] - batch[:, :, :, :-1], start_dim=2)
    diff = torch.cat((diff_v, diff_h), dim=2)
    # p-norm across channels (dim=1), then sum over every pixel pair.
    dist = torch.sum(torch.norm(diff, p=p, dim=1), dim=1)
    # detach() so inputs that require grad do not make .numpy() raise.
    return dist.detach().cpu().numpy()
def tv1_diff(batch1: torch.Tensor, batch2: torch.Tensor):
    """Return the per-image total variation (p=1) of ``batch1 - batch2``.

    Both batches are indexed as batch_size * color_channel * height * width
    and must be broadcast-compatible.
    """
    return total_variation(batch1 - batch2, p=1)
def tv2_diff(batch1: torch.Tensor, batch2: torch.Tensor):
    """Return the per-image total variation (p=2) of ``batch1 - batch2``.

    Both batches are indexed as batch_size * color_channel * height * width
    and must be broadcast-compatible.
    """
    return total_variation(batch1 - batch2, p=2)
####### Jaccard Distance ########
def rect_intersection(r1, r2):
    """Return the row-wise intersection of two batches of boxes.

    Columns 0-1 hold each box's lower corner and columns 2-3 its upper
    corner (presumably x1, y1, x2, y2). The intersection keeps the larger
    lower corner and the smaller upper corner. Rows that do not overlap
    are filled with -1.0, so ``rect_area`` reports 0 for them.
    """
    inter = torch.zeros_like(r1)
    inter[:, 0:2] = torch.max(r1[:, 0:2], r2[:, 0:2])
    inter[:, 2:4] = torch.min(r1[:, 2:4], r2[:, 2:4])
    # Strict '>' means boxes that merely touch keep a zero-area intersection.
    disjoint = (inter[:, 0] > inter[:, 2]) | (inter[:, 1] > inter[:, 3])
    inter[disjoint] = -1.0
    return inter
def rect_area(r):
    """Return the area of each box row (upper corner minus lower corner)."""
    extent0 = r[:, 2] - r[:, 0]
    extent1 = r[:, 3] - r[:, 1]
    return extent0 * extent1
def iou(r1, r2):
    """Return the row-wise intersection-over-union of two batches of boxes.

    NOTE(review): if both boxes in a row have zero area the union is 0 and
    the result is 0/0 (NaN).
    """
    inter_area = rect_area(rect_intersection(r1, r2))
    union_area = rect_area(r1) + rect_area(r2) - inter_area
    return inter_area / union_area
def jaccard_dist(batch1, batch2):
    """Return the Jaccard distance (1 - IoU) between paired boxes.

    Args:
        batch1, batch2: tensors of boxes, one per row. Columns 0-1 hold the
            lower corner and columns 2-3 the upper corner (presumably
            x1, y1, x2, y2).

    Returns:
        numpy array with one distance per row.
    """
    dist = 1 - iou(batch1, batch2)
    # detach() so inputs that require grad do not make .numpy() raise.
    return dist.detach().cpu().numpy()
######## L2 distance ##########
def l2_dist(batch1, batch2):
    """Return the per-sample Euclidean distance between two batches.

    Each sample's difference is flattened over all non-batch dimensions
    before its 2-norm is taken.

    Returns:
        numpy array of shape (batch_size,).
    """
    diff = torch.flatten(batch1 - batch2, start_dim=1)
    dist = torch.norm(diff, p=2, dim=1)
    # detach() so inputs that require grad do not make .numpy() raise.
    return dist.detach().cpu().numpy()