-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
94 lines (86 loc) · 15.9 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import time
import numpy as np
import os.path as osp
import torch
import pickle
# Loss weight for each HOI class (HOI presumably = human-object interaction;
# confirm against the dataset config). Built as a float32 NumPy array and
# reshaped to (1, 600), so the literal must contain exactly 600 values or the
# reshape raises.
# NOTE(review): values look like precomputed per-class weights (larger for some
# classes than others) - confirm how they were derived before editing them.
hoi_weight = np.array([39.86157711904592, 42.40444304026247, 44.834823527125415, 39.73991804849558, 39.350051894572104, 43.56609116327042, 37.78352039671705, 37.27879696661561, 56.2961038839078, 27.187192997462514, 48.69942543701149, 29.31727479695975, 40.91791293317506, 41.48885009402292, 48.00306615559755, 50.012214583404685, 41.67212390491823, 42.921511271001236, 28.880648716145704, 29.52003435670287, 28.7948786160738, 39.76397874615436, 59.30640384054761, 16.891810699279997, 45.78457865943398, 38.21737256387448, 41.27866658762785, 55.327003753827235, 50.27550397062817, 35.892058594766205, 23.930101537298007, 42.1883315501357, 39.01256606369551, 46.40605772692243, 41.63484517972581, 51.90277694560517, 39.86157711904592, 44.3232983026516, 26.24108136535154, 33.525745002186696, 36.51886783101932, 28.541819963426093, 30.205498384606926, 45.882177032325544, 59.30640384054761, 19.277187457752916, 40.44149658882279, 40.27550397062818, 34.99276619895774, 45.24100203620805, 53.86572339704485, 50.012214583404685, 48.69942543701149, 17.603639967967627, 38.389734264590764, 54.535191293350984, 35.660893887007894, 47.692723818197855, 48.00306615559755, 43.80412030999668, 30.501126058559556, 32.35158707564564, 54.535191293350984, 62.31670379718742, 23.78093836099101, 48.00306615559755, 62.31670379718742, 36.67004315466653, 44.535191293350984, 42.58542526119043, 62.31670379718742, 48.892476988965356, 48.16697031747924, 34.600828988374865, 39.88632331032448, 14.392366637557455, 62.31670379718742, 56.2961038839078, 40.3301329276432, 43.50856787437951, 54.535191293350984, 43.86572339704485, 45.414742996902284, 57.545491249990796, 62.31670379718742, 34.76558113323671, 45.15667036083943, 41.8245235704856, 43.285803927267985, 31.29236674037406, 62.31670379718742, 13.477875959525374, 47.84512348376523, 38.73735532718288, 32.71199602184443, 27.673298950910752, 52.31670379718742, 38.425042953542096, 42.49399146679174, 55.327003753827235, 55.327003753827235, 45.15667036083943, 
47.26520401398836, 45.414742996902284, 54.535191293350984, 41.27866658762785, 26.897169052605058, 55.327003753827235, 28.62825872892921, 27.253006626232384, 19.566586419407308, 42.44898645452497, 59.30640384054761, 51.52489133671117, 51.90277694560517, 37.18452779650803, 48.16697031747924, 39.86157711904592, 49.76397874615436, 49.30640384054761, 40.22155365176111, 49.30640384054761, 43.12592287342668, 50.27550397062817, 45.327003753827235, 41.244604100708735, 48.33730371046705, 57.545491249990796, 25.50972350021107, 48.69942543701149, 46.753678789514545, 33.13115849168469, 48.16697031747924, 40.824512670633624, 51.90277694560517, 62.31670379718742, 53.285803927267985, 42.97171928475174, 38.794878616073795, 30.481158460998806, 35.24100203620806, 31.561234183262115, 42.8228037307383, 42.31670379718742, 48.16697031747924, 21.951612511860127, 32.40887687915604, 43.92821288981487, 41.38248694556508, 57.545491249990796, 42.06364514453972, 51.52489133671117, 35.81362856586806, 30.962196803732283, 30.042979374291058, 30.29727316317092, 38.200506737555116, 50.55579120663061, 57.545491249990796, 19.54955885088451, 42.921511271001236, 49.529167787659134, 35.17340619973509, 38.13369088398997, 44.39278690220489, 55.327003753827235, 62.31670379718742, 40.55579120663061, 53.86572339704485, 2.660919516910389, 52.31670379718742, 52.31670379718742, 52.77427870279417, 22.931506545422504, 49.09451084984823, 42.631874311648076, 34.750342714728944, 39.88632331032448, 47.26520401398836, 59.30640384054761, 48.16697031747924, 62.31670379718742, 41.863474009320846, 41.21080669419493, 55.327003753827235, 25.14999404158607, 41.863474009320846, 41.59788372412617, 56.2961038839078, 59.30640384054761, 34.66001824959728, 37.45948953237162, 53.285803927267985, 28.543641286505434, 47.4030868588447, 57.545491249990796, 38.794878616073795, 26.440716499974968, 53.86572339704485, 49.76397874615436, 45.327003753827235, 38.83365516670581, 47.692723818197855, 48.33730371046705, 47.84512348376523, 
59.30640384054761, 59.30640384054761, 27.72277891959511, 32.56698385420673, 43.285803927267985, 49.30640384054761, 50.27550397062817, 32.22644637631832, 19.113163469010704, 52.77427870279417, 37.17122627058457, 56.2961038839078, 41.67212390491823, 37.292432597343094, 39.26319010272118, 51.90277694560517, 48.16697031747924, 57.545491249990796, 25.21383732015851, 45.327003753827235, 43.12592287342668, 33.7494148933586, 62.31670379718742, 37.05331102328898, 59.30640384054761, 34.37182333059573, 24.936036649412728, 39.350051894572104, 39.32817303309036, 24.636245656163254, 49.09451084984823, 44.46340544707976, 42.31670379718742, 40.793820353356864, 59.30640384054761, 37.157965360070634, 40.44149658882279, 32.457950224103485, 51.17727027411905, 42.921511271001236, 29.26534060775103, 20.70871826180388, 41.27866658762785, 33.15743168021626, 34.81934064149681, 33.8906114010818, 19.271367693122244, 43.56609116327042, 47.00191462676487, 57.545491249990796, 59.30640384054761, 19.587154348705695, 62.31670379718742, 45.78457865943398, 45.78457865943398, 62.31670379718742, 59.30640384054761, 62.31670379718742, 34.842585718323186, 37.2516534731387, 42.23070207956824, 37.51663436761591, 43.56609116327042, 40.763343422536806, 47.545491249990796, 51.90277694560517, 43.50856787437951, 25.56809238980931, 42.23070207956824, 56.2961038839078, 48.16697031747924, 46.51886783101932, 39.37204153557148, 48.892476988965356, 62.31670379718742, 62.31670379718742, 57.545491249990796, 30.025006771796413, 39.692192899883125, 35.00888104052353, 39.57512530455062, 62.31670379718742, 38.606025174470055, 41.143990840629776, 23.637317288279576, 43.92821288981487, 45.882177032325544, 54.535191293350984, 51.90277694560517, 27.024967764570196, 40.88655579464647, 41.70972539365131, 37.375157857002996, 50.27550397062817, 50.55579120663061, 43.231853608400925, 46.876023353684666, 45.24100203620805, 62.31670379718742, 16.424023402560195, 52.31670379718742, 44.68242386155804, 41.41765268279344, 
37.95507732677986, 45.882177032325544, 43.12592287342668, 62.31670379718742, 27.177871941076496, 46.08421089320842, 42.02286602033532, 59.30640384054761, 51.17727027411905, 59.30640384054761, 50.85542344040504, 43.340432884283004, 35.679694543290935, 39.46113070710968, 44.99276619895774, 34.09502300350725, 28.227523588719624, 62.31670379718742, 47.84512348376523, 50.55579120663061, 52.77427870279417, 32.15891623329701, 40.85542344040504, 41.863474009320846, 47.692723818197855, 54.535191293350984, 59.30640384054761, 25.613317385913, 42.44898645452497, 37.22467857387639, 35.8624811036965, 50.012214583404685, 38.532724787706044, 25.93880550356513, 40.19482775314784, 51.52489133671117, 48.69942543701149, 59.30640384054761, 36.15720328062341, 33.27496011434579, 45.414742996902284, 45.327003753827235, 59.30640384054761, 48.514591380071366, 43.50856787437951, 40.38545781364281, 54.535191293350984, 17.279205413234312, 40.67317523934305, 41.67212390491823, 62.31670379718742, 42.10481080648804, 38.87278106033631, 45.327003753827235, 30.602364787757338, 45.882177032325544, 57.545491249990796, 45.50429142343155, 38.932138861141375, 30.87719263294779, 45.982019241391555, 38.91226264878624, 33.71931813521595, 35.00081614532003, 41.8245235704856, 40.498267917739696, 35.8624811036965, 20.759093668408187, 39.09451084984823, 34.73515757751352, 46.2961038839078, 62.31670379718742, 34.66001824959728, 55.327003753827235, 23.89185955307172, 38.73735532718288, 46.753678789514545, 34.37879995027924, 50.85542344040504, 38.407352726153626, 25.571766624223923, 62.31670379718742, 55.327003753827235, 53.86572339704485, 37.2516534731387, 45.59572521783024, 44.834823527125415, 62.31670379718742, 33.41249360917828, 62.31670379718742, 57.545491249990796, 57.545491249990796, 46.63468655651747, 62.31670379718742, 62.31670379718742, 57.545491249990796, 53.86572339704485, 59.30640384054761, 28.45349805824696, 56.2961038839078, 51.17727027411905, 43.231853608400925, 55.327003753827235, 
49.76397874615436, 48.69942543701149, 33.27496011434579, 46.753678789514545, 48.514591380071366, 62.31670379718742, 31.942438817781188, 62.31670379718742, 43.12592287342668, 46.753678789514545, 45.414742996902284, 41.38248694556508, 39.668525567092054, 44.60818368076598, 45.15667036083943, 55.327003753827235, 62.31670379718742, 26.128903552125276, 59.30640384054761, 46.51886783101932, 56.2961038839078, 48.892476988965356, 33.22114350477567, 39.621574355008256, 44.39278690220489, 57.545491249990796, 29.10279101407053, 47.545491249990796, 62.31670379718742, 57.545491249990796, 41.52489133671117, 39.37204153557148, 51.52489133671117, 31.147307331679862, 40.61408664323785, 38.389734264590764, 42.8228037307383, 35.92183890450156, 59.30640384054761, 51.52489133671117, 52.77427870279417, 33.01740819634154, 40.94949812562335, 35.93181122764105, 38.200506737555116, 36.188865229990064, 32.39558891931792, 44.68242386155804, 30.061026662792713, 48.892476988965356, 30.590674485088822, 21.711991869210635, 55.327003753827235, 44.12126444176874, 42.77427870279417, 49.09451084984823, 41.38248694556508, 47.692723818197855, 55.327003753827235, 33.95979808226316, 32.813055253426185, 32.26060934358462, 24.087181391712605, 55.327003753827235, 46.51886783101932, 44.46340544707976, 40.35770727309509, 36.814420266636475, 34.316410204746084, 34.1808939115055, 33.73735114999313, 27.336356560317153, 38.442805533800126, 42.631874311648076, 62.31670379718742, 47.26520401398836, 25.597573672771553, 44.535191293350984, 39.78817348738849, 42.72628987397649, 41.143990840629776, 38.62454522308599, 37.0147068151566, 41.863474009320846, 37.55999191394312, 51.90277694560517, 43.99161467012506, 62.31670379718742, 62.31670379718742, 40.793820353356864, 22.874345359252622, 48.892476988965356, 52.31670379718742, 62.31670379718742, 32.133860712922115, 38.68058399826598, 41.70972539365131, 38.5509342266223, 62.31670379718742, 52.31670379718742, 51.90277694560517, 47.545491249990796, 47.545491249990796, 
62.31670379718742, 24.51353067578591, 38.73735532718288, 54.535191293350984, 37.347407316455275, 50.55579120663061, 57.545491249990796, 47.692723818197855, 59.30640384054761, 38.87278106033631, 40.44149658882279, 48.33730371046705, 62.31670379718742, 26.219692773393426, 45.59572521783024, 37.90761297653525, 44.834823527125415, 57.545491249990796, 29.810064602554988, 39.30640384054761, 38.46064106120429, 62.31670379718742, 43.02251454004449, 28.545463373722857, 44.75795524046251, 62.31670379718742, 43.12592287342668, 51.17727027411905, 47.545491249990796, 47.84512348376523, 33.17327222599302, 18.763338182063613, 59.30640384054761, 59.30640384054761, 62.31670379718742, 46.51886783101932, 53.86572339704485, 62.31670379718742, 57.545491249990796, 50.55579120663061, 44.834823527125415, 57.545491249990796, 55.327003753827235, 35.577283810846545, 38.661823948278425, 37.26520401398836, 62.31670379718742, 31.974431189481912, 51.17727027411905, 50.012214583404685, 49.09451084984823, 48.514591380071366, 21.94044709804023, 48.514591380071366, 40.643530449705665, 45.414742996902284, 41.902776945605176, 49.30640384054761, 37.87625583800666, 39.37204153557148, 50.85542344040504, 21.85956320777875, 31.97844685765432, 30.622898844067926, 54.535191293350984, 50.85542344040504, 52.77427870279417, 52.77427870279417, 31.27524829164734, 18.933732894860512, 50.55579120663061, 50.012214583404685, 62.31670379718742, 28.163630874931748, 47.84512348376523, 36.178285578426724, 45.59572521783024, 43.50856787437951, 62.31670379718742, 62.31670379718742, 19.859850154855067, 53.285803927267985, 59.30640384054761, 56.2961038839078, 43.50856787437951, 41.67212390491823]
, dtype = 'float32').reshape(1,600) # HOI loss weight, shape (1, 600)
# Expose the HOI weights as a torch tensor; torch.from_numpy shares memory with
# the NumPy array built just above.
hoi_weight = torch.from_numpy(hoi_weight)

# Load the verb <-> HOI-class mapping. The path is resolved relative to the
# current working directory, not to this file's location.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file, so
# only ever point this at a trusted verb_mapping.pkl. encoding='latin1' allows
# Python 3 to read a pickle (including NumPy arrays) written by Python 2.
# The file is opened in a `with` block so the handle is closed deterministically
# (the previous inline open() was never closed).
with open('verb_mapping.pkl', 'rb') as _verb_mapping_file:
    verb_mapping = pickle.load(_verb_mapping_file, encoding='latin1')
del _verb_mapping_file  # keep the module namespace free of the closed handle

# Per-verb loss weight: each row of verb_mapping is multiplied against the HOI
# weights, then divided by that row's sum. Assuming verb_mapping is a
# (num_verbs, num_hoi) 0/1 matrix (TODO confirm), this is the mean HOI weight
# over the HOI classes that involve each verb. A row summing to zero would
# yield inf/nan here. The final .view(-1) makes verb_weight a 1-D tensor.
verb_weight = np.matmul(verb_mapping, hoi_weight.transpose(1, 0).numpy())
verb_weight = torch.from_numpy((verb_weight.reshape(1, -1) / np.sum(verb_mapping, axis=1))).view(-1)
# Fixed 117-entry weight vector (presumably one weight per verb class - confirm
# against the consumer of `human_weight`).
# Built with torch.tensor and an explicit float32 dtype instead of the legacy
# torch.Tensor constructor, whose dtype silently follows the global default
# dtype; float32 matches hoi_weight above.
human_weight = torch.tensor([ 0.01398818, 5.79008917, 7.52538203, 7.49830107, 5.46175109,
 4.96587733, 6.62569903, 2.322367 , 2.68453539, 6.36591446,
 3.49891052, 2.27423848, 4.23727284, 7.89049584, 6.93541908,
 6.96857129, 6.50703033, 3.23165102, 5.5584068 , 9.80854903,
 6.1608923 , 5.79563703, 9.11540185, 5.0706651 , 4.94279131,
 6.27767383, 5.25305402, 6.11010013, 3.54779162, 7.55725724,
 6.3143752 , 7.36955209, 5.16563852, 8.27085503, 11.96803328,
 7.67301803, 10.58173892, 5.45332059, 7.2347631 , 7.05048841,
 5.75676409, 7.4282909 , 8.02322046, 8.93951119, 8.8922583 ,
 4.36181166, 4.13271768, 8.20683317, 5.44790319, 6.75309753,
 5.07810259, 5.85067057, 9.15462257, 10.35859537, 12.37349839,
 8.62399432, 7.59018202, 7.85170981, 8.90776249, 6.35368937,
 5.50396433, 7.89049584, 8.20683317, 8.14666465, 3.86453901,
 7.95465778, 11.45720766, 5.76347552, 6.44790659, 4.94071877,
 8.55578607, 4.67683131, 8.73591223, 6.77693 , 9.30544546,
 7.43185597, 6.19867116, 7.69136716, 6.58813593, 6.95296339,
 7.89616158, 8.92351085, 6.79944502, 8.01039977, 9.80854903,
 7.46452675, 7.61560712, 7.10306623, 6.32730895, 8.90776249,
 7.82489856, 6.12933149, 8.02969297, 5.90813158, 6.04110726,
 10.29405685, 9.13481994, 9.97560312, 11.12073542, 8.12500315,
 5.89345383, 5.9649696 , 7.43902446, 7.1108082 , 6.96408698,
 6.52129591, 7.70535341, 4.76363619, 9.51129751, 7.13175138,
 8.98910813, 12.37349839, 12.37349839, 6.75672729, 9.84776975,
 10.0709133 , 7.8196215 ], dtype=torch.float32)
# Fixed 117-entry weight vector (presumably one weight per verb class for the
# object side - confirm against the consumer of `object_weight`).
# Built with torch.tensor and an explicit float32 dtype instead of the legacy
# torch.Tensor constructor, whose dtype silently follows the global default
# dtype; float32 matches hoi_weight above.
object_weight = torch.tensor([ 0.03768974, 7.32770681, 8.36394877, 8.52510064, 6.37384689,
 6.08680075, 7.93549414, 3.42823752, 3.82750571, 7.49626217,
 4.44774845, 3.21095296, 5.25946638, 8.65318243, 7.84294258,
 7.84294258, 7.57549141, 4.09937092, 6.36561639, 9.88443919,
 6.89618159, 6.7876108 , 10.33338941, 6.12869679, 5.91535234,
 7.37674453, 6.80800976, 7.61209399, 4.65412734, 8.65318243,
 7.43013799, 8.50884012, 6.33763819, 9.39078137, 12.34829243,
 8.67834099, 11.24968015, 6.29855898, 8.15485697, 8.07162632,
 6.73425234, 7.04498753, 9.52989418, 10.02101473, 9.86338578,
 6.50474802, 6.2564171 , 9.2802395 , 6.31880982, 7.83743293,
 5.9647858 , 6.79727271, 9.92792431, 11.53736222, 13.04143962,
 8.10337501, 8.51423097, 9.05245557, 10.12366888, 7.66154226,
 6.52155514, 9.0711477 , 9.44412735, 9.0901959 , 4.81674005,
 9.14961932, 12.34829243, 6.85935471, 7.38719759, 6.05348779,
 9.97338668, 5.7904493 , 9.80276116, 7.71356345, 10.26885089,
 8.29650749, 7.19067495, 8.53058011, 7.77874943, 8.15109049,
 9.2802395 , 9.95039716, 7.92945183, 8.78582691, 11.24968015,
 8.5755315 , 8.78582691, 8.11056929, 7.40486589, 10.09700064,
 8.97241286, 7.04002474, 9.23477713, 6.89832752, 7.3310126 ,
 11.65514525, 10.30059959, 11.16963744, 11.78867665, 9.23477713,
 6.9200439 , 7.09732876, 8.43626943, 8.13616484, 7.98838361,
 7.45419096, 8.81460587, 5.69156591, 11.16963744, 8.07162632,
 9.86338578, 12.63597451, 13.7345868 , 7.84018396, 10.90137345,
 11.02653659, 8.53608976], dtype=torch.float32)
# Loss-weight lookup keyed by label name.
# 'labels_r' and 'labels_o' each receive their own (1, N) row view of the
# per-verb weights; 'labels_s' receives the flat 1-D verb_weight tensor itself.
loss_weight = dict(
    labels_r=verb_weight.reshape(1, -1),
    labels_o=verb_weight.reshape(1, -1),
    labels_s=verb_weight,
)
def restruct_pose(train_pose):
    """Group pose records by their image id.

    Args:
        train_pose: iterable of mapping-like records, each providing the keys
            'image_id', 'bboxes' and 'keypoints'.

    Returns:
        A plain dict mapping each image id to a list of
        ``[bboxes, keypoints]`` pairs, in the order the records appeared.
    """
    grouped = {}
    for record in train_pose:
        # Look up the image id first so a malformed record fails on the same
        # key as before.
        bucket = grouped.setdefault(record['image_id'], [])
        bucket.append([record['bboxes'], record['keypoints']])
    return grouped
class AverageMeter(object):
    """Track a running weighted average of a stream of values.

    Attributes:
        avg: weighted mean of every value passed to ``update`` since the last
            ``reset``.
        cnt: total weight (sum of ``k``) accumulated since the last ``reset``.
        val: the most recent value passed to ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the running average, the accumulated weight and the last value."""
        self.val = 0
        self.avg = 0
        self.cnt = 0

    def update(self, val, k=1):
        """Fold ``val`` into the running average with weight ``k``.

        Uses the incremental update ``avg += (val - avg) * k / (cnt + k)``,
        which equals the weighted mean of all values seen so far without
        keeping a running sum.

        Args:
            val: the new value to incorporate.
            k: weight of ``val`` (for example, how many samples it summarises).
                Defaults to 1. A weight of 0 leaves ``avg`` and ``cnt`` unchanged.
        """
        self.val = val
        if k == 0:
            # A zero weight cannot move the mean; returning early also avoids
            # a ZeroDivisionError when the meter is still empty (cnt == 0).
            return
        self.avg = self.avg + (val - self.avg) * k / (self.cnt + k)
        self.cnt += k

    def __str__(self):
        """Return the running average formatted to four decimals for logging."""
        return '%.4f' % self.avg