ActorObserverLoss.py
import torch
from torch import nn
from models.layers.EqualizeGradNorm import EqualizeGradNorm
from models.layers.VideoSoftmax import VideoSoftmax
from models.layers.MarginRank import MarginRank
from models.layers.DistRatio import DistRatio
from models.layers.BlockGradient import BlockGradient

VERBOSE = True


def dprint(message, *args):
    if VERBOSE:
        print(message.format(*args))


class ActorObserverLoss(nn.Module):
    """ActorObserver loss: a ranking subloss reweighted by per-example weights."""

    def __init__(self, args):
        super(ActorObserverLoss, self).__init__()
        # Pick the subloss class by name (e.g. 'MarginRank' or 'DistRatio').
        self.loss = globals()[args.subloss]
        # Per-stream statistics for the per-video softmaxes, keyed by video id.
        self.xstorage = {}
        self.ystorage = {}
        self.zstorage = {}
        # Running loss baseline per video id, maintained by update_constants.
        self.storage = {}
        self.decay = args.finaldecay
        self.xmax = VideoSoftmax(self.xstorage, args.decay)
        self.ymax = VideoSoftmax(self.ystorage, args.decay)
        self.zmax = VideoSoftmax(self.zstorage, args.decay)
        self.margin = args.margin

    def get_constants(self, ids):
        # Fetch the stored loss baseline for each video id in the batch.
        out = [self.storage[x][0] for x in ids]
        return torch.autograd.Variable(torch.Tensor(out).cuda())

    def update_constants(self, input, weights, ids):
        # Exponential moving average of the subloss per video, weighted by w.
        for x, w, vid in zip(input, weights, ids):
            x, w = x.data[0], w.data[0]
            if vid not in self.storage:
                self.storage[vid] = [x, w]
            else:
                # The baseline J is stored as a weighted average E[wJ] / E[w].
                old_x, old_w = self.storage[vid]
                val = (1 - self.decay) * w * x + self.decay * old_w * old_x
                new_weight = (1 - self.decay) * w + self.decay * old_w
                val = val / new_weight
                self.storage[vid] = [val, new_weight]
                if new_weight < 0.0001:
                    print('MILC new_weight is effectively 0')
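
    # Worked example of the update above (hypothetical numbers): with
    # decay = 0.9, a stored entry [val = 1.0, weight = 0.5], and a new
    # sample x = 2.0, w = 1.0:
    #   new_weight = 0.1 * 1.0 + 0.9 * 0.5 = 0.55
    #   val = (0.1 * 1.0 * 2.0 + 0.9 * 0.5 * 1.0) / 0.55 = 0.65 / 0.55 ≈ 1.18
    # so the baseline drifts toward new samples at a rate set by their weight.
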
    def forward(self, dist_a, dist_b, x, y, z, target, ids):
        # Normalize each stream's weights with its per-video softmax.
        x = self.xmax(x, ids)
        y = self.ymax(y, ids)
        z = self.zmax(z, ids)
        # Equalize gradient norms across the five inputs, then combine weights.
        dist_a, dist_b, x, y, z = EqualizeGradNorm.apply(dist_a, dist_b, x, y, z)
        w = x * y * z
        # Evaluate the subloss and update the per-video baseline k.
        loss = self.loss.apply(dist_a, dist_b, target, self.margin)
        self.update_constants(loss, w, ids)
        k = self.get_constants(ids)
        # n normalizes the weights to have mean one over the batch.
        n = (w.sum() + 0.00001) / w.shape[0]
        final = ((loss - k) * (w / n)).sum()
        dprint('loss before {}', loss.data.sum())
        dprint('loss after {}', (loss.data * w.data / n.data).sum())
        dprint('weight median: {}, var: {}', w.data.median(), w.data.var())
        return final, w.data.cpu()
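

# ---------------------------------------------------------------------------
# Hypothetical usage sketch, not part of the original repository file. It
# assumes the PyTorch 0.3-era API already used above (Variable, .cuda()),
# a CUDA device, and an args namespace carrying exactly the fields read in
# __init__; shapes, values, and ids below are illustrative only.
if __name__ == '__main__':
    from argparse import Namespace
    from torch.autograd import Variable

    args = Namespace(subloss='DistRatio', decay=0.9, finaldecay=0.9, margin=1.0)
    criterion = ActorObserverLoss(args)

    batch = 4
    # Pairwise embedding distances; exact semantics depend on the subloss.
    dist_a = Variable(torch.rand(batch).cuda(), requires_grad=True)
    dist_b = Variable(torch.rand(batch).cuda(), requires_grad=True)
    # Per-stream example weights, normalized per video inside forward().
    x = Variable(torch.rand(batch).cuda(), requires_grad=True)
    y = Variable(torch.rand(batch).cuda(), requires_grad=True)
    z = Variable(torch.rand(batch).cuda(), requires_grad=True)
    target = Variable(torch.ones(batch).cuda())  # ranking target for the subloss
    ids = ['vid%d' % i for i in range(batch)]    # one video id per example

    final, weights = criterion(dist_a, dist_b, x, y, z, target, ids)
    final.backward()
    print('loss: {:.4f}'.format(final.data[0]))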