forked from ChinaYi/ASFormer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbatch_gen.py
133 lines (110 loc) · 5.42 KB
/
batch_gen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
'''
Adapted from https://github.com/yabufarha/ms-tcn
'''
import torch
import numpy as np
import random
from grid_sampler import GridSampler, TimeWarpLayer
import os
class BatchGenerator(object):
    '''
    Serve shuffled mini-batches of per-frame features and frame-wise labels.

    For every video name ``vid`` the features are loaded from
    ``<features_path>/<vid>.npy`` (indexed as ``(C_in, num_frames)``) and the
    annotation from ``<gt_path>/<vid>.txt``. Each non-empty annotation line
    has the form ``<start> <end> <label>`` and labels the inclusive frame
    range ``start..end``.
    '''

    def __init__(self, num_classes, actions_dict, gt_path, features_path, sample_rate):
        '''
        :param num_classes: number of action classes (class axis of the mask)
        :param actions_dict: mapping from label string to integer class id
        :param gt_path: directory with one ``<vid>.txt`` annotation per video
        :param features_path: directory with one ``<vid>.npy`` file per video
        :param sample_rate: temporal stride applied to features and labels
        '''
        self.index = 0
        self.num_classes = num_classes
        self.actions_dict = actions_dict
        self.gt_path = gt_path
        self.features_path = features_path
        self.sample_rate = sample_rate
        self.timewarp_layer = TimeWarpLayer()
        # Populated by read_data(); start empty so has_next() is safe to call
        # before any data has been registered.
        self.list_of_examples = []
        self.gts = []
        self.features = []

    def reset(self):
        '''Rewind to the first example and reshuffle for a new epoch.'''
        self.index = 0
        self.my_shuffle()

    def has_next(self):
        '''Return True while at least one example has not been served yet.'''
        return self.index < len(self.list_of_examples)

    def read_data(self, vid_list):
        '''
        Register the videos to iterate over and shuffle them.

        :param vid_list: iterable of video names (no directory, no extension)
        '''
        # Copy the names: the internal lists are shuffled and extended in
        # place, which must not reorder or grow the caller's list.
        self.list_of_examples = list(vid_list)
        self.gts = [os.path.join(self.gt_path, vid + '.txt') for vid in self.list_of_examples]
        self.features = [os.path.join(self.features_path, vid + '.npy') for vid in self.list_of_examples]
        self.my_shuffle()

    def my_shuffle(self):
        '''Shuffle list_of_examples, gts and features with the same order.'''
        # Draw a single permutation and apply it to all three lists. This
        # keeps them aligned without reseeding the global ``random`` generator,
        # so every ordering is reachable and other users of ``random`` are
        # not disturbed.
        order = list(range(len(self.list_of_examples)))
        random.shuffle(order)
        self.list_of_examples[:] = [self.list_of_examples[i] for i in order]
        self.gts[:] = [self.gts[i] for i in order]
        self.features[:] = [self.features[i] for i in order]

    def warp_video(self, batch_input_tensor, batch_target_tensor):
        '''
        Apply one sampled temporal warp to both the features and the labels.

        :param batch_input_tensor: (bs, C_in, L_in)
        :param batch_target_tensor: (bs, L_in)
        :return: warped input and target
        '''
        bs, _, T = batch_input_tensor.shape
        # NOTE(review): GridSampler / TimeWarpLayer are defined in
        # grid_sampler.py; sample() is assumed to return a numpy array.
        grid_sampler = GridSampler(T)
        grid = torch.from_numpy(grid_sampler.sample(bs)).float()
        warped_batch_input_tensor = self.timewarp_layer(batch_input_tensor, grid, mode='bilinear')
        # Labels are categorical: add a channel axis, warp with 'nearest'
        # (no bilinear for labels!), then restore the original shape and dtype.
        batch_target_tensor = batch_target_tensor.unsqueeze(1).float()
        warped_batch_target_tensor = self.timewarp_layer(batch_target_tensor, grid, mode='nearest')
        warped_batch_target_tensor = warped_batch_target_tensor.squeeze(1).long()
        return warped_batch_input_tensor, warped_batch_target_tensor

    def merge(self, bg, suffix):
        '''
        Merge another batch generator into this one, e.g.

            a.merge(b, suffix='@1')

        :param bg: BatchGenerator whose examples are appended to this one
        :param suffix: appended to the merged video names to identify them;
            the ground-truth and feature paths are taken over unchanged
        :return: None
        '''
        self.list_of_examples += [vid + suffix for vid in bg.list_of_examples]
        self.gts += bg.gts
        self.features += bg.features
        print('Merge! Dataset length:{}'.format(len(self.list_of_examples)))

    def _load_example(self, features_file, gt_file):
        '''Load one video; return (features, labels) cropped and subsampled.'''
        features = np.load(features_file)
        with open(gt_file, 'r') as file_ptr:
            lines = file_ptr.read().splitlines()

        content = []
        for line in lines:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (including a trailing one).
                continue
            start, end, label = int(fields[0]), int(fields[1]), fields[2]
            content += [label] * (end - start + 1)

        # Features and annotation may disagree on the number of frames; keep
        # only the frames that have both, so inputs and targets always align.
        num_frames = min(np.shape(features)[1], len(content))
        classes = np.array(
            [self.actions_dict[label] for label in content[:num_frames]],
            dtype=np.int64,
        )
        feature = features[:, :num_frames:self.sample_rate]
        target = classes[::self.sample_rate]
        return feature, target

    def next_batch(self, batch_size, if_warp=False):
        '''
        Load the next ``batch_size`` videos and pad them to a common length.

        :param batch_size: maximum number of videos in the batch
        :param if_warp: if True, apply a strong data augmentation (temporal
            warping). See grid_sampler.py for details.
        :return: ``(batch_input_tensor, batch_target_tensor, mask, batch)``:
            float inputs ``(bs, C_in, L)`` padded with zeros, long targets
            ``(bs, L)`` padded with -100 (the default ``ignore_index`` of
            ``torch.nn.CrossEntropyLoss``), a float mask
            ``(bs, num_classes, L)`` that is 1 on valid frames, and the list
            of video names in the batch
        :raises IndexError: if no examples are left (call reset() first)
        '''
        start = self.index
        stop = start + batch_size
        batch = self.list_of_examples[start:stop]
        if not batch:
            raise IndexError('next_batch() called with no examples left; call reset() first')
        batch_gts = self.gts[start:stop]
        batch_features = self.features[start:stop]
        self.index = stop

        batch_input = []
        batch_target = []
        for features_file, gt_file in zip(batch_features, batch_gts):
            feature, target = self._load_example(features_file, gt_file)
            batch_input.append(feature)
            batch_target.append(target)

        num_videos = len(batch_input)
        max_len = max(len(target) for target in batch_target)
        batch_input_tensor = torch.zeros(
            num_videos, np.shape(batch_input[0])[0], max_len, dtype=torch.float
        )  # bs, C_in, L_in
        batch_target_tensor = torch.full((num_videos, max_len), -100, dtype=torch.long)
        mask = torch.zeros(num_videos, self.num_classes, max_len, dtype=torch.float)

        for i, (feature, target) in enumerate(zip(batch_input, batch_target)):
            length = len(target)
            input_tensor = torch.from_numpy(feature)
            target_tensor = torch.from_numpy(target)
            if if_warp:
                warped_input, warped_target = self.warp_video(
                    input_tensor.unsqueeze(0), target_tensor.unsqueeze(0)
                )
                input_tensor = warped_input.squeeze(0)
                target_tensor = warped_target.squeeze(0)
            batch_input_tensor[i, :, :length] = input_tensor
            batch_target_tensor[i, :length] = target_tensor
            mask[i, :, :length] = 1

        return batch_input_tensor, batch_target_tensor, mask, batch
# This module only defines BatchGenerator; running it as a script does nothing.
if __name__ == '__main__':
    pass