-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathPFV.py
63 lines (43 loc) · 2.25 KB
/
PFV.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import numpy as np
import torch
import torch.nn.functional as F
def normalize_and_scale_features(features, n_sigma=1):
# Whiten
scaled_features = (features - np.mean(features)) / (np.std(features) )
# Clip within n standard deviation
scaled_features = np.clip(scaled_features, -n_sigma, n_sigma)
# Rescale to [0,1] **This gives a batch scaling**
scaled_features = (scaled_features - scaled_features.min()) / (scaled_features.max()-scaled_features.min())
return scaled_features
def pca_decomposition(x, n_components=3):
feats = x.permute(0,2,3,1).reshape(-1, x.shape[1])
feats = feats-torch.mean(feats,0)
u,s,v = torch.svd(feats, compute_uv=True)
pc = torch.matmul(u[:,:n_components], torch.diag(s[:n_components]))
pc = pc.view(x.shape[0], x.shape[2], x.shape[3], 3).permute(0,3,1,2)
return pc
# Normalized activation maps
def feature_map_normalization(f):
act_map = torch.sum(f, dim=1).unsqueeze(1)
act_map /= act_map.max()
return act_map
# Principle feature visulization
def pfv(embeddings, image_shape=None, idx_layer=None, hierarchical=False, interp_mode='bilinear'):
if image_shape is None: image_shape = embeddings[0].shape[-2:]
if idx_layer is None: idx_layer = len(embeddings)
with torch.no_grad():
# Decompose to principle contrasting features
layer_to_visualize = pca_decomposition(embeddings[idx_layer], 3)
# Weighted upsampling with activation maps
if hierarchical:
for f in reversed(embeddings[:idx_layer]):
layer_to_visualize = F.interpolate(layer_to_visualize, size=(f.shape[2], f.shape[3]), mode=interp_mode)
layer_to_visualize *= feature_map_normalization(f)
else:
amap = [F.interpolate(torch.sum(x,dim=1).unsqueeze(1), size=image_shape,mode=interp_mode) for x in embeddings[:idx_layer]]
amap = torch.cat(amap, dim=1)
layer_to_visualize = F.interpolate(layer_to_visualize, size=image_shape,mode=interp_mode) * torch.sum(amap,dim=1).unsqueeze(1)
# Normalize response to RGB
layer_to_visualize = layer_to_visualize.detach().cpu().numpy()
rgb = normalize_and_scale_features(layer_to_visualize)
return rgb