-
Notifications
You must be signed in to change notification settings - Fork 260
/
dataset.py
137 lines (104 loc) · 3.79 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import copy
import functools
import json
import math
import os

import torch
import torch.utils.data as data
from PIL import Image
def pil_loader(path):
    """Read the image file at ``path`` with PIL and return it converted to RGB."""
    # The file is opened explicitly (rather than handing the path to PIL) to
    # avoid a ResourceWarning:
    # https://github.com/python-pillow/Pillow/issues/835
    with open(path, 'rb') as stream, Image.open(stream) as picture:
        return picture.convert('RGB')
def accimage_loader(path):
    """Load the image at ``path`` with accimage, falling back to PIL.

    Args:
        path: Filesystem path of the image to load.

    Returns:
        An ``accimage.Image`` when accimage can read the file; otherwise
        whatever ``pil_loader(path)`` returns.

    Raises:
        ImportError: If the ``accimage`` package is not installed.
    """
    # Import inside the function: no module-level import of accimage exists in
    # this file (the only other import is local to get_default_image_loader),
    # so without this line the name would be undefined here.
    import accimage

    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)
def get_default_image_loader():
    """Return the image-loading function matching torchvision's image backend.

    Returns:
        ``accimage_loader`` when torchvision reports the ``'accimage'``
        backend, ``pil_loader`` for every other backend.
    """
    from torchvision import get_image_backend

    if get_image_backend() != 'accimage':
        return pil_loader

    # The imported module is not used here; the import raises ImportError at
    # this point if accimage is selected but not installed.
    import accimage  # noqa: F401
    return accimage_loader
def video_loader(video_dir_path, frame_indices, image_loader):
    """Load the requested frames of a video stored as numbered JPEG files.

    Frame ``i`` is looked up as ``image_{i:05d}.jpg`` inside
    ``video_dir_path``.

    Args:
        video_dir_path: Directory holding the frame images.
        frame_indices: Iterable of integer frame numbers, in load order.
        image_loader: Callable taking an image path and returning the frame.

    Returns:
        List of loaded frames. Loading stops at the first frame file that
        does not exist, so the list may be shorter than ``frame_indices``.
    """
    frames = []
    for frame_no in frame_indices:
        frame_file = os.path.join(
            video_dir_path, 'image_{:05d}.jpg'.format(frame_no))
        if not os.path.exists(frame_file):
            # A missing frame ends the clip early; later indices are skipped.
            break
        frames.append(image_loader(frame_file))
    return frames
def get_default_video_loader():
    """Return ``video_loader`` with its ``image_loader`` argument pre-bound.

    The bound image loader is the one chosen by
    ``get_default_image_loader()``; the result is called as
    ``loader(video_dir_path, frame_indices)``.
    """
    return functools.partial(
        video_loader, image_loader=get_default_image_loader())
def load_annotation_data(data_file_path):
    """Parse the JSON annotation file at ``data_file_path``.

    Args:
        data_file_path: Path to a JSON file.

    Returns:
        The decoded JSON document (whatever ``json.load`` produces).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON interchange text is UTF-8, so do not depend on the locale default.
    with open(data_file_path, 'r', encoding='utf-8') as data_file:
        return json.load(data_file)
def get_class_labels(data):
    """Map each entry of ``data['labels']`` to its position in that sequence.

    Args:
        data: Mapping with a ``'labels'`` iterable of hashable labels.

    Returns:
        Dict from label to zero-based index. If a label occurs more than
        once, the index of its last occurrence is kept.
    """
    return {name: position for position, name in enumerate(data['labels'])}
def get_video_names_and_annotations(data, subset):
    """Collect the video names (and annotations) belonging to ``subset``.

    Args:
        data: Mapping with a ``'database'`` dict; each value has a
            ``'subset'`` key and, for non-testing entries, an
            ``'annotations'`` dict containing ``'label'``.
        subset: Subset name to select; entries whose ``'subset'`` differs
            are ignored.

    Returns:
        Tuple ``(video_names, annotations)``. For ``subset == 'testing'``
        names are ``'test/<key>'`` and no annotations are collected;
        otherwise names are ``'<label>/<key>'`` and the entry's
        ``'annotations'`` dict is appended alongside.
    """
    names, kept_annotations = [], []
    is_testing = subset == 'testing'

    for video_id, entry in data['database'].items():
        if entry['subset'] != subset:
            continue
        if is_testing:
            names.append('test/{}'.format(video_id))
            continue
        info = entry['annotations']
        names.append('{}/{}'.format(info['label'], video_id))
        kept_annotations.append(info)

    return names, kept_annotations
def make_dataset(video_path, sample_duration):
    """Split a directory of frames into consecutive, non-overlapping clips.

    Args:
        video_path: Directory whose entries are the video's frames.
        sample_duration: Number of frames per clip; also the stride between
            clip starts.

    Returns:
        List of dicts, one per full clip, with keys ``'video'`` (the given
        path), ``'segment'`` (``torch.IntTensor([first, last])``, 1-based
        and inclusive), ``'n_frames'`` (total frame count) and
        ``'frame_indices'`` (list of the clip's frame numbers). Trailing
        frames that do not fill a whole clip are dropped; the list is empty
        when there are fewer than ``sample_duration`` frames.

    Raises:
        OSError: If ``video_path`` cannot be listed.
        ValueError: If ``sample_duration`` is 0 (invalid ``range`` step).
    """
    # Every directory entry is counted as a frame.
    # NOTE(review): assumes the directory holds only frame images - confirm.
    n_frames = len(os.listdir(video_path))

    # Frames are numbered from 1, so the last clip that still fits starts at
    # n_frames - sample_duration + 1. range() excludes its stop value, hence
    # the "+ 1" below; without it the final full clip was silently dropped
    # whenever n_frames was an exact multiple of sample_duration.
    last_start = n_frames - sample_duration + 1

    dataset = []
    for start in range(1, last_start + 1, sample_duration):
        end = start + sample_duration - 1
        dataset.append({
            'video': video_path,
            'segment': torch.IntTensor([start, end]),
            'n_frames': n_frames,
            'frame_indices': list(range(start, end + 1)),
        })
    return dataset
class Video(data.Dataset):
    """Dataset yielding fixed-length clips from one directory of frames.

    Args:
        video_path: Directory of frame images, passed to ``make_dataset``.
        spatial_transform: Optional callable applied to every loaded frame.
        temporal_transform: Optional callable applied to the list of frame
            indices before loading.
        sample_duration: Frames per clip, passed to ``make_dataset``.
        get_loader: Zero-argument factory returning a callable
            ``loader(path, frame_indices)`` that loads the frames.
    """

    def __init__(self, video_path,
                 spatial_transform=None, temporal_transform=None,
                 sample_duration=16, get_loader=get_default_video_loader):
        self.data = make_dataset(video_path, sample_duration)
        self.spatial_transform = spatial_transform
        self.temporal_transform = temporal_transform
        self.loader = get_loader()

    def __getitem__(self, index):
        """Load one clip.

        Args:
            index (int): Index into the list of clips.

        Returns:
            tuple: ``(clip, target)`` where ``clip`` is the stacked frames
            with the first two axes swapped and ``target`` is the clip's
            ``'segment'`` entry.
        """
        sample = self.data[index]

        indices = sample['frame_indices']
        if self.temporal_transform is not None:
            indices = self.temporal_transform(indices)

        frames = self.loader(sample['video'], indices)
        if self.spatial_transform is not None:
            frames = [self.spatial_transform(frame) for frame in frames]

        # Stack along a new leading axis, then swap it with the next one.
        # NOTE(review): presumably each frame is (C, H, W), giving a
        # (C, T, H, W) clip - confirm against the spatial transform in use.
        clip = torch.stack(frames, 0).permute(1, 0, 2, 3)

        return clip, sample['segment']

    def __len__(self):
        """Return the number of clips."""
        return len(self.data)