data_loader.py
import utils

import torch
from torch.utils.data import Dataset, DataLoader

import numpy as np
import ujson as json
class MySet(Dataset):
    def __init__(self, input_file):
        # read one JSON record per line; materialize lists up front, since
        # Python 3's map() is lazy and __getitem__ needs random access
        with open('./data/' + input_file, 'r') as f:
            self.content = [json.loads(line) for line in f]
        self.lengths = [len(item['lngs']) for item in self.content]

    def __getitem__(self, idx):
        return self.content[idx]

    def __len__(self):
        return len(self.content)
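# Illustrative only (this record is invented; the field names are taken from
# collate_fn below): each line of the input file is one JSON-encoded
# trajectory, roughly of this shape:
_EXAMPLE_RECORD = {
    'driverID': 1, 'weekID': 3, 'dateID': 120, 'timeID': 450,  # categorical ids
    'dist': 12.4, 'time': 1380.0,                              # whole-trip stats
    'lngs': [116.31, 116.32], 'lats': [39.98, 39.99],          # GPS points
    'states': [1.0, 1.0],                                      # per-point state flags
    'time_gap': [0.0, 60.0], 'dist_gap': [0.0, 0.9],           # per-point gaps
}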
def collate_fn(data):
    stat_attrs = ['dist', 'time']
    info_attrs = ['driverID', 'dateID', 'weekID', 'timeID']
    traj_attrs = ['lngs', 'lats', 'states', 'time_gap', 'dist_gap']

    attr, traj = {}, {}

    lens = np.asarray([len(item['lngs']) for item in data])

    for key in stat_attrs:
        x = torch.FloatTensor([item[key] for item in data])
        attr[key] = utils.normalize(x, key)

    for key in info_attrs:
        attr[key] = torch.LongTensor([item[key] for item in data])

    for key in traj_attrs:
        # pad each sequence to the max length in the batch; keep seqs as a
        # plain list, since np.asarray rejects ragged (unequal-length) input
        seqs = [item[key] for item in data]
        mask = np.arange(lens.max()) < lens[:, None]
        padded = np.zeros(mask.shape, dtype=np.float32)
        padded[mask] = np.concatenate(seqs)

        if key in ['lngs', 'lats', 'time_gap', 'dist_gap']:
            padded = utils.normalize(padded, key)

        traj[key] = torch.from_numpy(padded).float()

    traj['lens'] = lens.tolist()

    return attr, traj
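# A minimal, self-contained sketch of the padding trick used above:
# broadcasting a row of positions against the per-sequence lengths yields a
# boolean mask that scatters the concatenated values into a zero matrix.
def _demo_padding():
    lens = np.asarray([3, 1, 2])
    seqs = [[1.0, 2.0, 3.0], [4.0], [5.0, 6.0]]
    mask = np.arange(lens.max()) < lens[:, None]  # shape (3, 3), True where valid
    padded = np.zeros(mask.shape, dtype=np.float32)
    padded[mask] = np.concatenate(seqs)
    # padded is now [[1, 2, 3], [4, 0, 0], [5, 6, 0]]
    return padded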
class BatchSampler:
    def __init__(self, dataset, batch_size):
        self.count = len(dataset)
        self.batch_size = batch_size
        self.lengths = dataset.lengths
        # a mutable list is required here: Python 3 range objects cannot be
        # shuffled or slice-assigned in place
        self.indices = list(range(self.count))

    def __iter__(self):
        '''
        Divide the data into chunks of size batch_size * 100,
        then sort each chunk by sequence length.
        '''
        np.random.shuffle(self.indices)

        chunk_size = self.batch_size * 100
        chunks = (self.count + chunk_size - 1) // chunk_size

        # re-arrange indices so that each batch holds sequences of similar
        # length, which minimizes the padding done in collate_fn
        for i in range(chunks):
            partial_indices = self.indices[i * chunk_size: (i + 1) * chunk_size]
            partial_indices.sort(key=lambda x: self.lengths[x], reverse=True)
            self.indices[i * chunk_size: (i + 1) * chunk_size] = partial_indices

        # yield batches of indices
        batches = (self.count + self.batch_size - 1) // self.batch_size
        for i in range(batches):
            yield self.indices[i * self.batch_size: (i + 1) * self.batch_size]

    def __len__(self):
        return (self.count + self.batch_size - 1) // self.batch_size
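# A minimal sketch (not part of the original pipeline) of the sampler's
# effect: after the in-chunk sort, each batch_size-sized slice holds
# trajectories of similar length, so collate_fn pads much less. _FakeSet is
# a hypothetical stand-in exposing only what BatchSampler actually uses.
def _demo_sampler():
    class _FakeSet:
        lengths = [5, 1, 4, 2, 3, 6]

        def __len__(self):
            return len(self.lengths)

    sampler = BatchSampler(_FakeSet(), batch_size=2)
    for batch in sampler:
        # each batch is a list of dataset indices grouped by similar length
        print(batch, [_FakeSet.lengths[i] for i in batch])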
def get_loader(input_file, batch_size):
    dataset = MySet(input_file=input_file)
    batch_sampler = BatchSampler(dataset, batch_size)

    # batch_sampler is mutually exclusive with batch_size, shuffle and
    # drop_last, so batching is controlled by the custom sampler alone
    data_loader = DataLoader(dataset=dataset,
                             collate_fn=collate_fn,
                             num_workers=4,
                             batch_sampler=batch_sampler,
                             pin_memory=True)

    return data_loader
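# Example usage, assuming a JSON-lines file named 'train' exists under
# ./data/ ('train' is a placeholder, not a file shipped with this snippet):
if __name__ == '__main__':
    loader = get_loader('train', batch_size=32)
    for attr, traj in loader:
        print(attr['dist'].shape, traj['lngs'].shape, len(traj['lens']))
        break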