-
Notifications
You must be signed in to change notification settings - Fork 0
/
seesaw.py
executable file
·899 lines (760 loc) · 41 KB
/
seesaw.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
import os
import argparse
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np
import copy
import logging
import wandb
from copy import deepcopy
from torch.nn import Parameter
#os.environ["CUDA_VISIBLE_DEVICES"]='1'
################## Helper function
def get_frame_num_per_cls(list_file, gt_path, actions_dict):
    """Count ground-truth frames per action class over a list of videos.

    Args:
        list_file: path of a text file listing one ground-truth file name per
            line; the empty entry after the final newline is dropped.
        gt_path: prefix concatenated verbatim with each name to locate its
            ground-truth file (plain string concatenation, so it must end
            with a path separator).
        actions_dict: mapping from action name to class index.

    Returns:
        np.ndarray of length ``len(actions_dict)`` with float frame counts,
        indexed by ``actions_dict[action]``.
    """
    with open(list_file, 'r') as file_ptr:
        list_of_examples = file_ptr.read().split('\n')[:-1]
    num_per_cls = np.zeros(len(actions_dict))
    for vid in list_of_examples:
        # Close every ground-truth file; previously one handle leaked per video.
        with open(gt_path + vid, 'r') as file_ptr:
            contents = file_ptr.read().split('\n')[:-1]
        for c in contents:
            num_per_cls[actions_dict[c]] += 1
    return num_per_cls
class Logger(object):
    """Attach a file handler and a console handler to the root logger.

    Args:
        logpath: directory created if it does not exist (an empty string
            means the current directory and is left alone).
        logfile: path of the log file, opened in 'w' mode (truncated).
            NOTE(review): it is not joined with ``logpath``; callers
            presumably pass a path inside ``logpath`` -- confirm.
        level: key of ``level_relations`` applied to the root logger.
        fmt: ``logging.Formatter`` format string used by both handlers.

    Both handlers are fixed at INFO. Every instantiation adds two more
    handlers to the root logger, so creating several Logger objects in one
    process duplicates each log line.
    """

    # Short level names accepted by the ``level`` argument.
    level_relations = {
        'debug': logging.DEBUG,
        'info': logging.INFO,
        'warning': logging.WARNING,
        'error': logging.ERROR,
        'crit': logging.CRITICAL
    }

    def __init__(self, logpath, logfile, level='info', fmt='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'):
        self.logger = logging.getLogger()
        self.logger.setLevel(self.level_relations.get(level))
        format_str = logging.Formatter(fmt)
        # exist_ok removes the exists()/makedirs() race; makedirs('') would
        # raise, so an empty path (current directory) is skipped.
        if logpath:
            os.makedirs(logpath, exist_ok=True)
        fh = logging.FileHandler(logfile, mode='w')
        fh.setLevel(logging.INFO)
        fh.setFormatter(format_str)
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)
        ch.setFormatter(format_str)
        self.logger.addHandler(fh)
        self.logger.addHandler(ch)
################## Metrics
def get_labels_start_end_time(frame_wise_labels, bg_class=["background"]):
    """Collapse per-frame labels into contiguous non-background segments.

    Args:
        frame_wise_labels: indexable sequence of per-frame labels.
        bg_class: labels treated as background; their runs produce no segment.

    Returns:
        ``(labels, starts, ends)`` lists, one entry per segment. A segment's
        end is the index where the following run begins, except for a
        segment that reaches the last frame, whose end is ``len - 1``. This
        asymmetry is preserved deliberately so scores stay comparable with
        earlier runs. An empty input yields three empty lists.
    """
    labels, starts, ends = [], [], []
    num_frames = len(frame_wise_labels)
    if num_frames == 0:
        # Previously frame_wise_labels[0] raised IndexError here.
        return labels, starts, ends
    last_label = frame_wise_labels[0]
    if last_label not in bg_class:
        labels.append(last_label)
        starts.append(0)
    for i in range(1, num_frames):
        current = frame_wise_labels[i]
        if current != last_label:
            if current not in bg_class:
                labels.append(current)
                starts.append(i)
            if last_label not in bg_class:
                ends.append(i)
            last_label = current
    if last_label not in bg_class:
        ends.append(num_frames - 1)
    return labels, starts, ends
def f_score(recognized, ground_truth, overlap, bg_class=["background"]):
    """Segmental TP/FP/FN counts for one video at IoU threshold ``overlap``.

    Each predicted segment is matched to the ground-truth segment with the
    highest IoU among those sharing its label. It is a true positive when
    that IoU is >= ``overlap`` and the ground-truth segment is not already
    matched; otherwise it is a false positive. Unmatched ground-truth
    segments are false negatives.

    Returns:
        ``(tp, fp, fn)`` as floats.
    """
    p_label, p_start, p_end = get_labels_start_end_time(recognized, bg_class)
    y_label, y_start, y_end = get_labels_start_end_time(ground_truth, bg_class)
    if len(y_label) == 0:
        # No ground-truth segments: argmax over an empty IoU array used to
        # raise ValueError. Every predicted segment is a false positive.
        return 0.0, float(len(p_label)), 0.0
    tp = 0
    fp = 0
    hits = np.zeros(len(y_label))
    for j in range(len(p_label)):
        intersection = np.minimum(p_end[j], y_end) - np.maximum(p_start[j], y_start)
        union = np.maximum(p_end[j], y_end) - np.minimum(p_start[j], y_start)
        # Zero out IoU for ground-truth segments with a different label.
        IoU = (1.0 * intersection / union) * ([p_label[j] == y for y in y_label])
        # Get the best scoring segment
        idx = np.array(IoU).argmax()
        if IoU[idx] >= overlap and not hits[idx]:
            tp += 1
            hits[idx] = 1
        else:
            fp += 1
    fn = len(y_label) - sum(hits)
    return float(tp), float(fp), float(fn)
def overlap_f1(P, Y, overlap=.1, bg_class=["background"]):
    """Micro-averaged segmental F1 score, in percent, over a batch of videos.

    TP/FP/FN from ``f_score`` are summed across all pairs ``(P[k], Y[k])``
    before precision and recall are formed; small epsilons keep the divisions
    finite and any remaining nan is mapped to 0.
    """
    TP = FP = FN = 0
    for k, recognized in enumerate(P):
        tp, fp, fn = f_score(recognized, Y[k], overlap, bg_class)
        TP, FP, FN = TP + tp, FP + fp, FN + fn
    precision = TP / float(TP + FP + 1e-8)
    recall = TP / float(TP + FN + 1e-8)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-16)
    return np.nan_to_num(f1) * 100
def accuracy(P, Y):
    """Frame-wise accuracy in percent, pooled over every frame of every video.

    ``P[k]`` and ``Y[k]`` are compared elementwise with ``==``, so they must be
    array-likes of the same length. The result is a one-element torch.Tensor.
    """
    num_frames = 0.
    num_correct = 0
    for k, pred in enumerate(P):
        gt = Y[k]
        num_frames += len(gt)
        num_correct += (pred == gt).sum()
    return torch.Tensor([100 * num_correct / num_frames])
def levenstein_(p, y, norm=False):
    """Levenshtein (insert/delete/substitute, unit cost) distance.

    Args:
        p, y: label sequences whose elements are compared with ``==``.
        norm: if True, return ``(1 - dist / max(len(p), len(y))) * 100``.

    Returns:
        The raw distance (float), or the normalized score. When both
        sequences are empty the normalized score is 100.0; previously this
        case divided 0/0 and produced nan, which then poisoned ``np.mean``
        in ``edit_score``.
    """
    m_row = len(p)
    n_col = len(y)
    D = np.zeros([m_row + 1, n_col + 1], 'float')
    D[:, 0] = np.arange(m_row + 1)
    D[0, :] = np.arange(n_col + 1)
    for j in range(1, n_col + 1):
        for i in range(1, m_row + 1):
            if y[j - 1] == p[i - 1]:
                D[i, j] = D[i - 1, j - 1]
            else:
                D[i, j] = min(D[i - 1, j], D[i, j - 1], D[i - 1, j - 1]) + 1
    if not norm:
        return D[-1, -1]
    longest = max(m_row, n_col)
    if longest == 0:
        return 100.0
    return (1 - D[-1, -1] / longest) * 100
def edit_score(P, Y, norm=True, bg_class=["background"]):
    """Segmental edit score between predicted and ground-truth labelings.

    If ``P`` is exactly a ``list`` it is treated as a batch: each pair
    ``(P[k], Y[k])`` is scored recursively and the results are averaged with
    ``np.mean``. Otherwise ``P`` and ``Y`` are single per-frame sequences that
    are collapsed into segment labels (background removed) and compared with
    ``levenstein_``.
    """
    if type(P) is not list:
        pred_segments, _, _ = get_labels_start_end_time(P, bg_class)
        true_segments, _, _ = get_labels_start_end_time(Y, bg_class)
        return levenstein_(pred_segments, true_segments, norm)
    per_video = [edit_score(pred, Y[k], norm, bg_class) for k, pred in enumerate(P)]
    return np.mean(per_video)
# balanced metric
def f_score_ana(recognized, ground_truth, overlap, actions_dict, bg_class=["background"]):
    """Per-class segmental TP/FP/FN counts for one video.

    The matching rule is the same as ``f_score``, but counts are accumulated
    into arrays of length ``len(actions_dict)``, indexed by
    ``actions_dict[label]``. TP/FP are attributed to the predicted segment's
    label and FN to the unmatched ground-truth segment's label.

    Returns:
        ``(tp, fp, fn)`` float numpy arrays.
    """
    p_label, p_start, p_end = get_labels_start_end_time(recognized, bg_class)
    y_label, y_start, y_end = get_labels_start_end_time(ground_truth, bg_class)
    tp = np.zeros(len(actions_dict))
    fp = np.zeros(len(actions_dict))
    fn = np.zeros(len(actions_dict))
    if len(y_label) == 0:
        # No ground-truth segments: argmax over an empty IoU array used to
        # raise ValueError. Every predicted segment is a false positive.
        for label in p_label:
            fp[actions_dict[label]] += 1
        return tp, fp, fn
    hits = np.zeros(len(y_label))
    for j in range(len(p_label)):
        intersection = np.minimum(p_end[j], y_end) - np.maximum(p_start[j], y_start)
        union = np.maximum(p_end[j], y_end) - np.minimum(p_start[j], y_start)
        # Zero out IoU for ground-truth segments with a different label.
        IoU = (1.0 * intersection / union) * ([p_label[j] == y for y in y_label])
        # Get the best scoring segment
        idx = np.array(IoU).argmax()
        if IoU[idx] >= overlap and not hits[idx]:
            tp[actions_dict[p_label[j]]] += 1
            hits[idx] = 1
        else:
            fp[actions_dict[p_label[j]]] += 1
    for label, hit in zip(y_label, hits):
        if hit == 0:
            fn[actions_dict[label]] += 1
    return tp, fp, fn
def overlap_f1_macro(P, Y, overlap=.1, bg_class=["background"]):
    """Per-class segmental F1 scores, in percent, over a batch of videos.

    Per-class TP/FP/FN arrays from ``f_score_ana`` are summed over all pairs
    ``(P[k], Y[k])``. The class indexing comes from the module-level
    ``actions_dict`` (a global, not a parameter). Returns a numpy array with
    one F1 value per class; nan entries are mapped to 0.
    """
    TP = FP = FN = 0
    for k, recognized in enumerate(P):
        tp, fp, fn = f_score_ana(recognized, Y[k], overlap, actions_dict, bg_class)
        TP = TP + tp
        FP = FP + fp
        FN = FN + fn
    precision = TP / (TP + FP + 1e-8)
    recall = TP / (TP + FN + 1e-8)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-16)
    return np.nan_to_num(f1) * 100
def b_accuracy(P, Y):
    """Per-class frame recall and precision, in percent.

    Frames are compared up to the shorter of ``P[k]`` and ``Y[k]``. Recall
    divides correct frames by ground-truth frames of each class; precision
    divides them by frames predicted as that class. Class indices come from
    the module-level ``actions_dict``.

    Returns:
        ``(avg_acc, avg_prec)`` numpy arrays of length ``len(actions_dict)``.
    """
    n_cls = len(actions_dict)
    total = np.zeros(n_cls)
    correct = np.zeros(n_cls)
    cover = np.zeros(n_cls)
    for k, pred in enumerate(P):
        for p_frame, y_frame in zip(pred, Y[k]):
            gt_idx = actions_dict[y_frame]
            if p_frame == y_frame:
                correct[gt_idx] += 1
            total[gt_idx] += 1
            cover[actions_dict[p_frame]] += 1
    avg_acc = 100 * correct / (total + 1e-8)
    avg_prec = 100 * correct / (cover + 1e-8)
    return avg_acc, avg_prec
################## Dataloader
class BatchGenerator(torch.utils.data.Dataset):
    """Dataset yielding ``(features, labels, mask, video_name)`` per video.

    Args:
        num_classes: number of rows of the all-ones mask that is returned.
        actions_dict: mapping from action name to class index.
        gt_path: prefix concatenated with each listed name to locate its
            per-frame ground-truth text file.
        features_path: prefix concatenated with ``name.split('.')[0] + '.npy'``
            to locate the feature array.
        sample_rate: temporal subsampling stride applied to features and labels.
        vid_list_file: text file listing one video (ground-truth file) name
            per line; the list is shuffled once on load.
    """

    def __init__(self, num_classes, actions_dict, gt_path, features_path, sample_rate, vid_list_file):
        self.list_of_examples = list()
        self.index = 0
        self.num_classes = num_classes
        self.actions_dict = actions_dict
        self.gt_path = gt_path
        self.features_path = features_path
        self.sample_rate = sample_rate
        self.read_data(vid_list_file)

    def __len__(self):
        return len(self.list_of_examples)

    def read_data(self, vid_list_file):
        """Load the video list (dropping the empty entry after the final newline) and shuffle it."""
        with open(vid_list_file, 'r') as file_ptr:
            self.list_of_examples = file_ptr.read().split('\n')[:-1]
        random.shuffle(self.list_of_examples)

    def getitem(self, index):
        """Return ``(input, target, mask, vid)`` for the video at ``index``.

        ``input`` is a float tensor ``features[:, :T:sample_rate]``; ``target``
        is a long tensor of class indices with the same stride; ``mask`` is
        ``ones(num_classes, input_len)``. Here ``T`` is the smaller of the
        feature length (``features.shape[1]``) and the label count.
        """
        vid = self.list_of_examples[index]
        # assumes features are stored as (feature_dim, num_frames) -- the
        # frame axis is axis 1, as the shape[1] / [:, ::stride] uses imply.
        features = np.load(self.features_path + vid.split('.')[0] + '.npy')
        # Close the ground-truth file; previously the handle leaked per item.
        with open(self.gt_path + vid, 'r') as file_ptr:
            content = file_ptr.read().split('\n')[:-1]
        num_frames = min(np.shape(features)[1], len(content))
        classes = np.zeros(num_frames)
        for i in range(num_frames):
            classes[i] = self.actions_dict[content[i]]
        # Truncate features to the same frame count as the labels so input
        # and target lengths always agree (they could differ before).
        feats = features[:, :num_frames:self.sample_rate]
        target = classes[::self.sample_rate]
        batch_input_tensor = torch.from_numpy(feats).float()
        batch_target_tensor = torch.from_numpy(target).long()
        mask = torch.ones(self.num_classes, np.shape(feats)[1])
        return batch_input_tensor, batch_target_tensor, mask, vid

    def __getitem__(self, index):
        return self.getitem(index)
################## Model (change loss)
class MultiStageModel(nn.Module):
    """Cascade of ``SingleStageModel`` stages.

    ``stage1`` maps ``dim``-channel input features to class logits. Each
    following stage refines the masked softmax of the previous stage's
    logits. ``num_stages - 2`` middle stages are built (none when that is
    <= 0), and the final ``stage2`` is the only one given ``use_norm``. So
    ``max(num_stages, 2)`` stage outputs are produced in total.
    """

    def __init__(self, num_stages, num_layers, num_f_maps, dim, num_classes, use_norm =0):
        super(MultiStageModel, self).__init__()
        self.stage1 = SingleStageModel(num_layers, num_f_maps, dim, num_classes)
        middle = []
        for _ in range(num_stages - 2):
            middle.append(SingleStageModel(num_layers, num_f_maps, num_classes, num_classes))
        self.stages = nn.ModuleList(middle)
        self.stage2 = SingleStageModel(num_layers, num_f_maps, num_classes, num_classes, use_norm)

    def forward(self, x, mask):
        """Return all stage logits stacked along a new leading stage axis."""
        frame_mask = mask[:, 0:1, :]
        out = self.stage1(x, mask)
        per_stage = [out]
        for refine in self.stages:
            out = refine(F.softmax(out, dim=1) * frame_mask, mask)
            per_stage.append(out)
        out = self.stage2(F.softmax(out, dim=1) * frame_mask, mask)
        per_stage.append(out)
        return torch.stack(per_stage, dim=0)
class NormedLinear(nn.Module):
    """Cosine-similarity classifier layer.

    Input rows (``dim=1``) and weight columns (``dim=0``) are L2-normalized
    before the matrix product, so every output lies in [-1, 1]. The weight
    has shape ``(in_features, out_features)``.
    """

    def __init__(self, in_features, out_features):
        super(NormedLinear, self).__init__()
        weight = torch.Tensor(in_features, out_features)
        # Uniform init, then renorm so each column has L2 norm <= 1e-5,
        # rescaled by 1e5.
        weight.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.weight = Parameter(weight)

    def forward(self, x):
        unit_x = F.normalize(x, dim=1)
        unit_w = F.normalize(self.weight, dim=0)
        return unit_x.mm(unit_w)
class SingleStageModel(nn.Module):
    """One temporal-convolution stage.

    The stage applies a 1x1 conv from ``dim`` input channels to
    ``num_f_maps``, then ``num_layers`` dilated residual layers (dilation
    2**i), then a per-frame classifier. The classifier is a 1x1 conv, or a
    cosine ``NormedLinear`` layer when ``use_norm`` is truthy. Logits are
    multiplied by ``mask[:, 0:1, :]``.
    """

    def __init__(self, num_layers, num_f_maps, dim, num_classes, use_norm =0):
        super(SingleStageModel, self).__init__()
        self.conv_1x1 = nn.Conv1d(dim, num_f_maps, 1)
        self.layers = nn.ModuleList([DilatedResidualLayer(2 ** i, num_f_maps, num_f_maps) for i in range(num_layers)])
        self.use_norm = use_norm
        self.num_f_maps = num_f_maps
        if self.use_norm:
            self.conv_out = NormedLinear(num_f_maps, num_classes)
        else:
            self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)

    def forward(self, x, mask):
        """Map ``x`` (batch, dim, frames) to masked logits (batch, num_classes, frames)."""
        out = self.conv_1x1(x)
        for layer in self.layers:
            out = layer(out, mask)
        if self.use_norm:
            # Flatten (batch, frames) into rows for the cosine classifier and
            # restore the real batch size afterwards. The previous
            # unsqueeze(0) only worked for batch size 1.
            batch_size, _, num_frames = out.shape
            out = out.transpose(2, 1).reshape(-1, self.num_f_maps)
            out = self.conv_out(out)
            out = out.view(batch_size, num_frames, -1).transpose(2, 1).contiguous() * mask[:, 0:1, :]
        else:
            out = self.conv_out(out) * mask[:, 0:1, :]
        return out
class DilatedResidualLayer(nn.Module):
    """Residual block with a masked skip connection.

    The branch is: dilated conv (kernel 3, padding = dilation, so length is
    preserved) -> ReLU -> 1x1 conv -> dropout (p=0.5). The branch output is
    added to the input and the sum is multiplied by ``mask[:, 0:1, :]``.
    """

    def __init__(self, dilation, in_channels, out_channels):
        super(DilatedResidualLayer, self).__init__()
        self.conv_dilated = nn.Conv1d(in_channels, out_channels, kernel_size=3,
                                      padding=dilation, dilation=dilation)
        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, kernel_size=1)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x, mask):
        branch = self.dropout(self.conv_1x1(F.relu(self.conv_dilated(x))))
        return (x + branch) * mask[:, 0:1, :]
def seesaw_ce_loss(cls_score, labels, cum_samples, num_classes, p, q, eps):
    """Seesaw cross-entropy loss (mean over rows).

    Negative-class logits are shifted by ``log(w)`` before standard cross
    entropy; the true-class logit is left unchanged. The weight ``w`` is the
    product of two factors:

    * mitigation (``p > 0``): for true class i and other class j with
      ``N_j < N_i`` (counts clamped to >= 1), ``w *= (N_j / N_i) ** p``.
    * compensation (``q > 0``): where the (detached) softmax score of class
      j exceeds the true-class score, ``w *= (s_j / s_i) ** q``; ``s_i`` is
      clamped to >= ``eps``.

    Args:
        cls_score (torch.Tensor): logits of shape (N, C), C == num_classes.
        labels (torch.Tensor): integer class labels of shape (N,).
        cum_samples (torch.Tensor): accumulated label count per class, (C,).
        num_classes (int): number of classes C.
        p (float): exponent of the mitigation factor.
        q (float): exponent of the compensation factor.
        eps (float): lower bound of the true-class score divisor.

    Returns:
        torch.Tensor: scalar loss.
    """
    assert cls_score.size(-1) == num_classes
    assert len(cum_samples) == num_classes

    onehot = F.one_hot(labels, num_classes)
    weights = cls_score.new_ones(onehot.size())

    if p > 0:
        counts = cum_samples.clamp(min=1)
        # ratio[i, j] = N_j / N_i; only entries below 1 are damped.
        ratio = counts[None, :] / counts[:, None]
        below = (ratio < 1.0).float()
        weights = weights * (ratio.pow(p) * below + (1 - below))[labels.long(), :]

    if q > 0:
        probs = F.softmax(cls_score.detach(), dim=1)
        rows = torch.arange(0, len(probs)).to(probs.device).long()
        own = probs[rows, labels.long()]
        rel = probs / own[:, None].clamp(min=eps)
        above = (rel > 1.0).float()
        weights = weights * (rel.pow(q) * above + (1 - above))

    adjusted = cls_score + (weights.log() * (1 - onehot))
    return F.cross_entropy(adjusted, labels, weight=None, reduction='mean')
class SeesawLoss(nn.Module):
    """Seesaw loss with per-class sample counts accumulated across calls.

    Args:
        p (float, optional): The ``p`` in the mitigation factor. Defaults to 0.8.
        q (float, optional): The ``q`` in the compensation factor. Defaults to 2.0.
        num_classes (int, optional): The number of classes. Defaults to 48.
    """

    def __init__(self, p=0.8, q=2.0, num_classes=48):
        super(SeesawLoss, self).__init__()
        self.p = p
        self.q = q
        self.num_classes = num_classes
        self.cls_criterion = seesaw_ce_loss
        # Running label count per class. It is placed on the module-level
        # ``device`` and is a plain attribute, not a registered buffer, so
        # ``.to()`` on this module does not move it and it is not in
        # state_dict (Trainer saves/restores it explicitly).
        self.cum_samples = torch.zeros(self.num_classes, dtype=torch.float).to(device)
        self.eps = 1e-2

    def forward(self, cls_score, labels):
        """
        Args:
            cls_score (torch.Tensor): The prediction with shape (N, C).
            labels (torch.Tensor): Integer class labels of the prediction.
        """
        assert cls_score.size(-1) == self.num_classes
        # Accumulate the samples for each category with one bincount. The
        # previous per-class loop forced a host sync via .item() for every
        # class present in the batch.
        counts = torch.bincount(labels.reshape(-1), minlength=self.num_classes)
        self.cum_samples += counts.to(device=self.cum_samples.device, dtype=self.cum_samples.dtype)
        loss_cls = self.cls_criterion(cls_score, labels, self.cum_samples, self.num_classes, self.p, self.q, self.eps)
        return loss_cls
################## Trainer (change loss)
class Trainer:
def __init__(self, model, log, sample_rate, **kwargs):
set_seed(seed)
self.model = model(use_norm = args.norm, **kwargs)
self.ce = nn.CrossEntropyLoss(ignore_index=-100)
self.ce_split = nn.CrossEntropyLoss(ignore_index=-100, reduction='none')
self.mse = nn.MSELoss(reduction='none')
self.num_classes = kwargs.get('num_classes', 0)
self.sample_rate = sample_rate
assert self.num_classes > 0, "wrong class numbers"
self.log = log
self.log.logger.info('Model Size: {}'.format(sum(p.numel() for p in self.model.parameters())))
self.suppress_loss = SeesawLoss(float(args.p), float(args.q), self.num_classes)
def train(self, save_dir, num_epochs, batch_size, learning_rate, device, actions_dict,
batch_gen_tst=None):
self.model.train()
self.model.to(device)
resume_epoch = 0
optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
best_score = -10000
if args.resume > 0:
state = torch.load("./models/" + args.dataset + "/split_" + args.split + TYPE + "/epoch-last" + ".model",map_location = device)
resume_epoch = state['epoch'] + 1
self.model.load_state_dict(state['net'])
best_score = state['score']
self.suppress_loss.cum_samples = state['cum_samples']
for epoch in range(resume_epoch, num_epochs):
epoch_loss = 0
correct = 0
total = 0
ce_loss, smooth_loss = [0, 0, 0, 0], [0, 0, 0, 0]
for _, items in enumerate(train_loader):
batch_input, batch_target, mask, vids = items
batch_input, batch_target, mask = batch_input.to(device), batch_target.to(device), mask.to(device)
optimizer.zero_grad()
activity = vids[0].split('.txt')[0].split('_')[-1]
ps = self.model(batch_input, mask)
loss = 0
for i, p in enumerate(ps):
if i == len(ps) - 1:
s_ce_loss = self.suppress_loss(p.transpose(2, 1).contiguous().view(-1, self.num_classes),batch_target.view(-1))
else:
s_ce_loss = self.ce(p.transpose(2, 1).contiguous().view(-1, self.num_classes), batch_target.view(-1))
s_smooth_loss = 0.15 * torch.mean(torch.clamp(
self.mse(F.log_softmax(p[:, :, 1:], dim=1), F.log_softmax(p.detach()[:, :, :-1], dim=1)), min=0,
max=16) * mask[:, :, 1:])
loss += s_ce_loss
loss += s_smooth_loss
ce_loss[i] += s_ce_loss.item()
smooth_loss[i] += s_smooth_loss.item()
epoch_loss += loss.item()
loss.backward()
optimizer.step()
_, predicted = torch.max(ps.data[-1], 1)
correct += ((predicted == batch_target).float() * mask[:, 0, :].squeeze(1)).sum().item()
total += torch.sum(mask[:, 0, :]).item()
nums = len(batch_gen.list_of_examples)
pr_str = "[epoch %d]: loss = %f, ce1 = %f, ce2 = %f, ce3 = %f, ce4 = %f, " \
"sm1 = %f, sm2 = %f, sm3 = %f, sm4 = %f, acc = %f" % \
(epoch + 1, epoch_loss/nums, np.round(ce_loss[0]/nums,3), np.round(ce_loss[1]/nums,3),
np.round(ce_loss[2] / nums, 3), np.round(ce_loss[3]/nums,3), np.round(smooth_loss[0]/nums,3),
np.round(smooth_loss[1] / nums, 3), np.round(smooth_loss[2]/nums,3), np.round(smooth_loss[3]/nums,3),
float(correct) / total)
self.log.logger.info(pr_str)
test_score, test_log, CM_tst = self.test(epoch, actions_dict, device)
if test_score > best_score:
best_score = test_score
best_save = {'net': self.model.state_dict(), 'epoch': epoch, 'score': best_score, 'cum_samples': self.suppress_loss.cum_samples}
torch.save(best_save, save_dir + "/epoch-best" + ".model")
self.log.logger.info("Save for the best model")
last_save = {'net': self.model.state_dict(), 'epoch': epoch, 'score': best_score, 'cum_samples': self.suppress_loss.cum_samples}
torch.save(last_save, save_dir + "/epoch-last" + ".model")
log_dict = {'epoch': epoch}
log_dict.update(test_log)
if args.is_wandb:
train_log, CM = self.train_eval()
log_dict.update(train_log)
wandb.log(log_dict)
def train_eval(self):
self.model.eval()
total_frames, confusion_matrix = 0, np.zeros((num_classes, num_classes))
preds, labels = [], []
epoch_loss = 0
loss_d, conf_d = {}, {}
ce_loss, smooth_loss = [0, 0, 0, 0], [0, 0, 0, 0]
for i, items in enumerate(train_ad_loader):
batch_input, batch_target, mask, vids = items
batch_input, batch_target, mask = batch_input.to(device), batch_target.to(device), mask.to(device)
total_frames += len(batch_target.view(-1))
activity = vids[0].split('.txt')[0].split('_')[-1]
ps = self.model(batch_input, mask)
_, predicted = torch.max(ps[-1], 1)
pred1s, lbls = predicted.view(-1, 1).detach().cpu().numpy(), batch_target.view(-1, 1).cpu().numpy()
np.add.at(confusion_matrix, (lbls, pred1s), 1)
# loss
loss = 0
for i, p in enumerate(ps):
s_ce_loss_split = self.ce_split(p.transpose(2, 1).contiguous().view(-1, self.num_classes),
batch_target.view(-1))
s_smooth_loss = 0.15 * torch.mean(torch.clamp(self.mse(F.log_softmax(p[:, :, 1:], dim=1),
F.log_softmax(p.detach()[:, :, :-1], dim=1)),
min=0, max=16) * mask[:, :, 1:])
s_ce_loss = torch.mean(s_ce_loss_split)
if i == len(ps) - 1:
confidence = F.softmax(p, dim=1).squeeze(0)[
batch_target.view(-1), torch.arange(len(batch_target.view(-1)))]
current_l = np.unique(batch_target.view(-1).cpu().numpy())
for k in current_l:
if k not in loss_d:
loss_d[k] = []
conf_d[k] = []
tmpt_loss = s_ce_loss_split[batch_target.view(-1) == k]
tmpt_conf = confidence[batch_target.view(-1) == k]
loss_d[k].append(torch.mean(tmpt_loss).item())
conf_d[k].append(torch.mean(tmpt_conf).item())
loss += s_ce_loss
loss += s_smooth_loss
ce_loss[i] += s_ce_loss.item()
smooth_loss[i] += s_smooth_loss.item()
epoch_loss += loss.item()
_, predicted = torch.max(ps[-1], 1)
predicted = predicted.squeeze().cpu()
batch_target = batch_target.squeeze().cpu()
predicted_word, label_word = [], []
for i in range(len(predicted)):
predicted_word += [label_dict[predicted[i].item()]] * self.sample_rate
label_word += [label_dict[batch_target[i].item()]] * self.sample_rate
preds.append(np.array(predicted_word))
labels.append(np.array(label_word))
assert np.sum(confusion_matrix) == total_frames
confusion_matrix /= total_frames
results = {}
# action: standard metrics
results['f1_10'] = overlap_f1_macro(preds, labels, overlap=0.1)
results['f1_25'] = overlap_f1_macro(preds, labels, overlap=0.25)
results['f1_50'] = overlap_f1_macro(preds, labels, overlap=0.50)
results['f_rec'], results['f_prec'] = b_accuracy(preds, labels)
results1 = {}
# action: standard metrics
results1['f1_10'] = overlap_f1(preds, labels, overlap=0.1)
results1['f1_25'] = overlap_f1(preds, labels, overlap=0.25)
results1['f1_50'] = overlap_f1(preds, labels, overlap=0.50)
results1['f_acc'] = accuracy(preds, labels).item()
results1['edit'] = edit_score(preds, labels)
self.model.train()
head_loss = np.mean([np.mean(loss_d[i]) for i in range(num_classes) if ((i in loss_d) and (i in freq_list))])
com_loss = np.mean([np.mean(loss_d[i]) for i in range(num_classes) if ((i in loss_d) and (i in common_list))])
tail_loss = np.mean([np.mean(loss_d[i]) for i in range(num_classes) if ((i in loss_d) and (i in rare_list))])
head_conf = np.mean([np.mean(conf_d[i]) for i in range(num_classes) if ((i in conf_d) and (i in freq_list))])
com_conf = np.mean([np.mean(conf_d[i]) for i in range(num_classes) if ((i in conf_d) and (i in common_list))])
tail_conf = np.mean([np.mean(conf_d[i]) for i in range(num_classes) if ((i in conf_d) and (i in rare_list))])
head_recall = np.mean([results['f_rec'][i] for i in freq_list if i in loss_d])
com_recall = np.mean([results['f_rec'][i] for i in common_list if i in loss_d])
tail_recall = np.mean([results['f_rec'][i] for i in rare_list if i in loss_d])
head_prec = np.mean([results['f_prec'][i] for i in freq_list if i in loss_d])
com_prec = np.mean([results['f_prec'][i] for i in common_list if i in loss_d])
tail_prec = np.mean([results['f_prec'][i] for i in rare_list if i in loss_d])
f1 = 2 * results['f_rec'] * results['f_prec'] / (results['f_rec'] + results['f_prec'] + 1e-8)
head_f1 = np.mean([f1[i] for i in freq_list if i in loss_d])
com_f1 = np.mean([f1[i] for i in common_list if i in loss_d])
tail_f1 = np.mean([f1[i] for i in rare_list if i in loss_d])
nums = len(batch_gen_ad.list_of_examples)
log_dict = {'train_loss': ce_loss[3] / nums,
'train_head_loss': head_loss, 'train_com_loss': com_loss, 'train_tail_loss': tail_loss,
'train_head_conf': head_conf, 'train_com_conf': com_conf, 'train_tail_conf': tail_conf,
'train_mean_rec': np.mean([results['f_rec'][i] for i in range(num_classes) if i in loss_d]),
'train_mean_prec': np.mean([results['f_prec'][i] for i in range(num_classes) if i in loss_d]),
'train_head_rec': head_recall, 'train_com_rec': com_recall, 'train_rare_rec': tail_recall,
'train_head_prec': head_prec, 'train_com_prec': com_prec, 'train_rare_prec': tail_prec,
'train_head_f1': head_f1, 'train_com_f1': com_f1, 'train_rare_f1': tail_f1,
'train_glb_acc': results1['f_acc'], 'train_edit': results1['edit'],
'train_glb_f1_10': results1['f1_10'], 'train_glb_f1_25': results1['f1_25'], 'train_glb_f1_50': results1['f1_50'],
'train_cls_f1_10': np.mean([results['f1_10'][i] for i in range(num_classes) if i in loss_d]),
'train_cls_f1_25': np.mean([results['f1_25'][i] for i in range(num_classes) if i in loss_d]),
'train_cls_f1_50': np.mean([results['f1_50'][i] for i in range(num_classes) if i in loss_d]),}
return log_dict, confusion_matrix
def test(self, epoch, actions_dict, device):
self.model.eval()
total_frames, confusion_matrix = 0, np.zeros((num_classes, num_classes))
preds = []
labels = []
epoch_loss = 0
ce_loss, smooth_loss = [0, 0, 0, 0], [0, 0, 0, 0]
loss_d, conf_d = {}, {}
with torch.no_grad():
for i, items in enumerate(test_loader):
batch_input, batch_target, mask, vids = items
batch_input, batch_target, mask = batch_input.to(device), batch_target.to(device), mask.to(device)
total_frames += len(batch_target.view(-1))
activity = vids[0].split('.txt')[0].split('_')[-1]
ps = self.model(batch_input, mask)
loss = 0
for i, p in enumerate(ps):
s_ce_loss_split = self.ce_split(p.transpose(2, 1).contiguous().view(-1, self.num_classes),
batch_target.view(-1))
s_smooth_loss = 0.15 * torch.mean(torch.clamp(
self.mse(F.log_softmax(p[:, :, 1:], dim=1), F.log_softmax(p.detach()[:, :, :-1], dim=1)), min=0,
max=16) * mask[:, :, 1:])
s_ce_loss = torch.mean(s_ce_loss_split)
if i == len(ps) - 1:
confidence = F.softmax(p, dim=1).squeeze(0)[
batch_target.view(-1), torch.arange(len(batch_target.view(-1)))]
current_l = np.unique(batch_target.view(-1).cpu().numpy())
for k in current_l:
if k not in loss_d:
loss_d[k] = []
conf_d[k] = []
tmpt_loss = s_ce_loss_split[batch_target.view(-1) == k]
tmpt_conf = confidence[batch_target.view(-1) == k]
loss_d[k].append(torch.mean(tmpt_loss).item())
conf_d[k].append(torch.mean(tmpt_conf).item())
loss += s_ce_loss
loss += s_smooth_loss
ce_loss[i] += s_ce_loss.item()
smooth_loss[i] += s_smooth_loss.item()
epoch_loss += loss.item()
_, predicted = torch.max(ps[-1], 1)
pred1s, lbls = predicted.view(-1, 1).detach().cpu().numpy(), batch_target.view(-1, 1).cpu().numpy()
#pred_activity = self.decide_activity(pred1s, activity)
np.add.at(confusion_matrix, (lbls, pred1s), 1)
predicted = predicted.squeeze().cpu()
batch_target = batch_target.squeeze().cpu()
predicted_word, label_word = [], []
for i in range(len(predicted)):
predicted_word += [label_dict[predicted[i].item()]] * self.sample_rate
label_word += [label_dict[batch_target[i].item()]] * self.sample_rate
preds.append(np.array(predicted_word))
labels.append(np.array(label_word))
assert np.sum(confusion_matrix) == total_frames
confusion_matrix /= total_frames
nums = len(batch_gen_tst.list_of_examples)
pr_str = "***[epoch %d]***: loss = %f, ce1 = %f, ce2 = %f, ce3 = %f, ce4 = %f, " \
"sm1 = %f, sm2 = %f, sm3 = %f, sm4 = %f" % \
(epoch + 1, epoch_loss / nums, np.round(ce_loss[0] / nums, 3), np.round(ce_loss[1] / nums, 3),
np.round(ce_loss[2] / nums, 3), np.round(ce_loss[3] / nums, 3),
np.round(smooth_loss[0] / nums, 3),
np.round(smooth_loss[1] / nums, 3), np.round(smooth_loss[2] / nums, 3),
np.round(smooth_loss[3] / nums, 3))
self.log.logger.info(pr_str)
results1 = {}
# action: standard metrics
results1['f1_10'] = overlap_f1(preds, labels, overlap=0.1)
results1['f1_25'] = overlap_f1(preds, labels, overlap=0.25)
results1['f1_50'] = overlap_f1(preds, labels, overlap=0.50)
results1['f_acc'] = accuracy(preds, labels).item()
results1['edit'] = edit_score(preds, labels)
results = {}
# action: standard metrics
results['f1_10_s'] = overlap_f1_macro(preds, labels, overlap=0.1)
results['f1_25_s'] = overlap_f1_macro(preds, labels, overlap=0.25)
results['f1_50_s'] = overlap_f1_macro(preds, labels, overlap=0.50)
results['f_rec'], results['f_prec'] = b_accuracy(preds, labels)
test_not_appear = np.array([i for i in range(num_classes) if i not in loss_d])
valid_num = num_classes - len(test_not_appear)
results['f_acc'] = np.sum(results['f_rec']) / valid_num
results['f1_10'] = np.sum(results['f1_10_s']) / valid_num
results['f1_25'] = np.sum(results['f1_25_s']) / valid_num
results['f1_50'] = np.sum(results['f1_50_s']) / valid_num
results1['total_score'] = (results1['f1_10'] + results1['f1_25'] + results1['f1_50'])/3.0 + results1['f_acc'] + results1['edit'] + (results['f1_10'] + results['f1_25'] + results['f1_50'])/3.0 + results['f_acc']
self.log.logger.info(
"---[epoch %d]---: tst edit = %f, f1_10 = %f, f1_25 = %f, f1_50 = %f, acc = %f, total = %f "
% (epoch + 1, results1['edit'], results1['f1_10'], results1['f1_25'], results1['f1_50'], results1['f_acc'], results1['total_score']))
self.log.logger.info(
" balanced acc = %f, f1_10 = %f, f1_25 = %f, f1_50 = %f" % (results['f_acc'], results['f1_10'], results['f1_25'], results['f1_50']))
self.model.train()
head_loss = np.mean([np.mean(loss_d[i]) for i in range(num_classes) if ((i in loss_d) and (i in freq_list))])
com_loss = np.mean([np.mean(loss_d[i]) for i in range(num_classes) if ((i in loss_d) and (i in common_list))])
tail_loss = np.mean([np.mean(loss_d[i]) for i in range(num_classes) if ((i in loss_d) and (i in rare_list))])
head_conf = np.mean([np.mean(conf_d[i]) for i in range(num_classes) if ((i in conf_d) and (i in freq_list))])
com_conf = np.mean([np.mean(conf_d[i]) for i in range(num_classes) if ((i in conf_d) and (i in common_list))])
tail_conf = np.mean([np.mean(conf_d[i]) for i in range(num_classes) if ((i in conf_d) and (i in rare_list))])
head_recall = np.mean([results['f_rec'][i] for i in freq_list if i in loss_d])
com_recall = np.mean([results['f_rec'][i] for i in common_list if i in loss_d])
tail_recall = np.mean([results['f_rec'][i] for i in rare_list if i in loss_d])
head_prec = np.mean([results['f_prec'][i] for i in freq_list if i in loss_d])
com_prec = np.mean([results['f_prec'][i] for i in common_list if i in loss_d])
tail_prec = np.mean([results['f_prec'][i] for i in rare_list if i in loss_d])
f1 = 2 * results['f_rec'] * results['f_prec'] / (results['f_rec'] + results['f_prec'] + 1e-8)
head_f1 = np.mean([f1[i] for i in freq_list if i in loss_d])
com_f1 = np.mean([f1[i] for i in common_list if i in loss_d])
tail_f1 = np.mean([f1[i] for i in rare_list if i in loss_d])
exist_rec = np.array([results['f_rec'][i] for i in range(num_classes) if i in loss_d])
exist_prec = np.array([results['f_prec'][i] for i in range(num_classes) if i in loss_d])
log_dict = {'test_loss': ce_loss[3] / nums,
'test_head_loss': head_loss, 'test_com_loss': com_loss, 'test_tail_loss': tail_loss,
'test_head_conf': head_conf, 'test_com_conf': com_conf, 'test_tail_conf': tail_conf,
'test_mean_rec': np.mean(exist_rec), 'test_mean_prec': np.mean(exist_prec),
'test_head_rec': head_recall, 'test_com_rec': com_recall, 'test_rare_rec': tail_recall,
'test_head_prec': head_prec, 'test_com_prec': com_prec, 'test_rare_prec': tail_prec,
'test_head_f1': head_f1, 'test_com_f1': com_f1, 'test_rare_f1': tail_f1,
'test_glb_acc': results1['f_acc'], 'test_edit': results1['edit'],
'test_glb_f1_10': results1['f1_10'], 'test_glb_f1_25': results1['f1_25'],
'test_glb_f1_50': results1['f1_50'],
'test_cls_f1_10': results['f1_10'], 'test_cls_f1_25': results['f1_25'], 'test_cls_f1_50': results['f1_50']}
return results1['total_score'], log_dict, confusion_matrix
def predict(self, model_dir, results_dir, features_path, batch_gen_tst, actions_dict, sample_rate, device):
    """Run the best checkpoint over the test set and write frame-level predictions.

    Loads ``<model_dir>/epoch-best.model`` (a dict whose ``'net'`` entry is the
    model state dict), runs ``self.model`` on each test video's features and
    writes one text file per video into ``results_dir``. Each file is named after
    the video's basename without extension.

    Args:
        model_dir: directory containing ``epoch-best.model``.
        results_dir: existing output directory for the per-video files.
        features_path: path prefix; features are read from ``<prefix><vid stem>.npy``.
        batch_gen_tst: dataset whose items unpack into four values; only the
            last (the video name) is used here.
        actions_dict: mapping from label name to class index.
        sample_rate: temporal stride applied to the feature columns. Each predicted
            label is repeated this many times to restore the original frame count.
        device: torch device to run the model on.
    """
    # Build the loader from the argument. The old code read a module-level
    # ``test_loader`` global and ignored ``batch_gen_tst`` entirely.
    test_loader = torch.utils.data.DataLoader(dataset=batch_gen_tst, batch_size=1, shuffle=False)

    # Invert the label map once instead of rebuilding two lists for every frame.
    # setdefault keeps the first name for a duplicated index, matching list.index().
    index_to_label = {}
    for label, index in actions_dict.items():
        index_to_label.setdefault(index, label)

    self.model.eval()
    with torch.no_grad():
        self.model.to(device)
        state = torch.load(model_dir + "/epoch-best" + ".model", map_location=device)
        self.model.load_state_dict(state['net'])
        for _, _, _, vids in test_loader:
            vid = vids[0]
            # Presumably (feature_dim, frames); the stride is applied to axis 1.
            features = np.load(features_path + vid.split('.')[0] + '.npy')
            features = features[:, ::sample_rate]
            input_x = torch.tensor(features, dtype=torch.float).unsqueeze(0).to(device)
            predictions = self.model(input_x, torch.ones(input_x.size(), device=device))
            _, predicted = torch.max(predictions[-1].data, 1)
            # reshape(-1) instead of squeeze(): a one-frame video would otherwise
            # collapse to a 0-d tensor that has no len().
            recognition = []
            for index in predicted.reshape(-1).tolist():
                recognition.extend([index_to_label[index]] * sample_rate)
            f_name = vid.split('/')[-1].split('.')[0]
            with open(results_dir + "/" + f_name, "w") as f_ptr:
                f_ptr.write("### Frame level recognition: ###\n")
                f_ptr.write(' '.join(recognition))
# Run on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
parser = argparse.ArgumentParser()
# Only 'train' and 'predict' have a code path below. Any other value used to
# fall through silently, so reject it at parse time.
parser.add_argument('--action', default='train', choices=['train', 'predict'],
                    help='train a model, or write predictions with the best checkpoint')
parser.add_argument('--dataset', default="breakfast")
# --seed/--split/--p/--q stay strings: they are spliced into run names and paths.
parser.add_argument('--seed', default='42')
parser.add_argument('--split', default='1')
parser.add_argument('--p', default='0.4')
parser.add_argument('--q', default='0.5')
parser.add_argument('--norm', default=0, type=int)
parser.add_argument('--is_wandb', action='store_true', help='To log results on wandb')
parser.add_argument('--resume', default=0, type=int, help='resume from the latest saved model when > 0')
args = parser.parse_args()
def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: integer seed applied to ``random``, ``numpy``, ``torch`` and all
            CUDA devices.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes convolution algorithms and may choose different
    # kernels from run to run. Disable it explicitly in case something enabled it.
    torch.backends.cudnn.benchmark = False
seed = int(args.seed)
# set_seed was defined above but never called, so the seed encoded in the run
# name had no effect. Seed every RNG before the model is constructed.
set_seed(seed)
TYPE = '/mstcn_seesaw_p{}_q{}{}_{}'.format(args.p, args.q, '-norm' if args.norm else '', args.seed)
logpath = "results/" + args.dataset + "/split_{}/".format(args.split) + TYPE
logfile = logpath + '/' + time.strftime('%Y%m%d%H%M', time.localtime(time.time())) + '.log'
log = Logger(logpath, logfile, fmt="[%(asctime)s - %(levelname)s]: %(message)s")
log.logger.info('########################## MS-TCN #####################################')
log.logger.info("Training for MS-TCN with seesaw")
log.logger.info(args)

# MS-TCN architecture and optimisation hyper-parameters.
num_stages = 4
num_layers = 10
num_f_maps = 64
features_dim = 2048
bz = 1
lr = 0.0005
num_epochs = 50
# use the full temporal resolution @ 15fps
sample_rate = 1
# sample input features @ 15fps instead of 30 fps for 50salads, and up-sample the output to 30 fps
if args.dataset == "50salads":
    sample_rate = 2

vid_list_file = "../../../data/" + args.dataset + "/splits/train.split" + args.split + ".bundle"
vid_list_file_tst = "../../../data/" + args.dataset + "/splits/test.split" + args.split + ".bundle"
features_path = "../../../data/" + args.dataset + "/features/"
gt_path = "../../../data/" + args.dataset + "/groundTruth/"
mapping_file = "../../../data/" + args.dataset + "/mapping.txt"
model_dir = "./models/" + args.dataset + "/split_" + args.split + TYPE
results_dir = "./results/" + args.dataset + "/split_" + args.split + TYPE
os.makedirs(model_dir, exist_ok=True)
os.makedirs(results_dir, exist_ok=True)

# mapping.txt holds one "<index> <label>" pair per line. splitlines() plus the
# blank-line filter keeps the last entry even when the file has no trailing
# newline; the old split('\n')[:-1] dropped it in that case.
with open(mapping_file, 'r') as file_ptr:
    actions = [line for line in file_ptr.read().splitlines() if line.strip()]
actions_dict, label_dict = dict(), dict()
for a in actions:
    idx_str, act_name = a.split()[:2]
    actions_dict[act_name] = int(idx_str)
    label_dict[int(idx_str)] = act_name
num_classes = len(actions_dict)

import json
# Frequent/common/rare class groups for the selected dataset. The path was
# previously hard-coded to breakfast even when --dataset named another one.
with open('../data/{}_frame_bin3.json'.format(args.dataset), 'r') as f:
    group_dict = json.load(f)
freq_list = np.array([actions_dict[i] for i in group_dict['frequent']])
common_list = np.array([actions_dict[i] for i in group_dict['common']])
rare_list = np.array([actions_dict[i] for i in group_dict['rare']])

trainer = Trainer(MultiStageModel, log, sample_rate, num_stages=num_stages, num_layers=num_layers, num_f_maps=num_f_maps,
                  dim=features_dim, num_classes=num_classes)
if args.action == "train":
    if args.is_wandb:
        wandb.init(project='new_mstcn', entity='pang_neurips',
                   name='mstcn_seesaw_p{}_q{}{}_split{}_{}'.format(args.p, args.q, '-norm' if args.norm else '', args.split, args.seed),
                   resume=(args.resume > 0))
    batch_gen = BatchGenerator(num_classes, actions_dict, gt_path, features_path, sample_rate, vid_list_file)
    batch_gen_ad = BatchGenerator(num_classes, actions_dict, gt_path, features_path, sample_rate, vid_list_file)
    batch_gen_tst = BatchGenerator(num_classes, actions_dict, gt_path, features_path, sample_rate, vid_list_file_tst)
    # NOTE(review): these loaders are module-level globals that trainer.train is
    # never passed explicitly; presumably Trainer methods read them. Keep the names.
    train_loader = torch.utils.data.DataLoader(dataset=batch_gen, batch_size=1, shuffle=True, pin_memory=True, num_workers=2)
    train_ad_loader = torch.utils.data.DataLoader(dataset=batch_gen_ad, batch_size=1, shuffle=True, pin_memory=False, num_workers=2)
    test_loader = torch.utils.data.DataLoader(dataset=batch_gen_tst, batch_size=1, shuffle=True, pin_memory=False, num_workers=2)
    trainer.train(model_dir, num_epochs=num_epochs, batch_size=bz, learning_rate=lr, device=device,
                  actions_dict=actions_dict, batch_gen_tst=batch_gen_tst)
    if args.is_wandb:
        wandb.finish()
if args.action == "predict":
    batch_gen_tst = BatchGenerator(num_classes, actions_dict, gt_path, features_path, sample_rate, vid_list_file_tst)
    # Kept as a global because Trainer.predict may read a module-level test_loader.
    test_loader = torch.utils.data.DataLoader(dataset=batch_gen_tst, batch_size=1, shuffle=True, pin_memory=False, num_workers=2)
    prediction_dir = os.path.join(results_dir, 'prediction')
    os.makedirs(prediction_dir, exist_ok=True)
    trainer.predict(model_dir, prediction_dir, features_path, batch_gen_tst, actions_dict, sample_rate, device)