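"""
preprocess_data.py

Prepares pickled networkx graph time series for the DAMNETS experiments:
splits them into train/validation/test sets, computes adjacency delta
matrices, and converts the networks to torch_geometric format.

Example invocation (the dataset name 'ba' is hypothetical; point -d at a
pickle of networkx time series inside the data directory):

    python preprocess_data.py -p data -d ba -t 0.8 -v 0.1 -n 4
"""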
import os
import argparse
import gc
import random
import numpy as np
import torch
import networkx as nx
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map
from functools import partial
from torch_geometric.utils.convert import from_networkx
from utils.graph_generators import *
from utils.arg_helper import mkdir
from utils.graph_utils import save_graph_list, load_graph_ts


def parse_arguments():
    parser = argparse.ArgumentParser(
        description='Script to generate the synthetic data for the DAMNETS experiments.')
    parser.add_argument(
        '-n',
        '--num_workers',
        type=int,
        default=1,
        help='Number of workers used to generate the data.'
    )
    parser.add_argument(
        '-p',
        '--data_path',
        type=str,
        default='data',
        help='The data directory.'
    )
    parser.add_argument(
        '-d',
        '--dataset_name',
        type=str,
        default='',
        help='The name of the pickled dataset in the directory, stored as lists of networkx graphs. '
             'Do not include .pkl.'
    )
    parser.add_argument(
        '-t',
        '--train_p',
        type=float,
        help='The proportion of total data to use for training. The remainder is used for test.'
    )
    parser.add_argument(
        '-v',
        '--val_p',
        type=float,
        help='The proportion of the training data to use for validation.'
    )
    parser.add_argument(
        '-r',
        '--randomize',
        # argparse's type=bool treats any non-empty string (including 'False')
        # as True, so parse the flag value explicitly.
        type=lambda s: s.lower() in ('true', '1', 'yes'),
        default=True,
        help='If True, randomly shuffle the time series before the train/test/validation split.'
    )
    return parser.parse_args()


def format_data(graph_ts, n_workers, abs=False):
    '''
    Prepare the data. This involves computing the delta matrices for the
    network time series and putting the observed networks into torch_geometric
    format for computation (the deltas are later inserted into the TreeLib
    underlying the bigg model).
    Args:
        graph_ts: The list of network time series to pre-process.
        n_workers: Number of worker processes used for the conversion.
        abs: If True, use the absolute value of the adjacency deltas.
    Returns: A list pairing each previous network (together with its flattened
        lower-triangular adjacency labels) with the delta matrix we want to learn.
    '''
    print('Computing Deltas')
    delta_fn = partial(compute_adj_delta, abs=abs)
    diffs = process_map(delta_fn, graph_ts, max_workers=n_workers)
    # Drop the last observation of each series: it has no successor, so there
    # is no delta to predict from it.
    graph_ts = [ts[:-1] for ts in graph_ts]
    # Flatten (the deltas will go into the TreeLib in this order).
    diffs = [nx.Graph(diff) for diff_ts in diffs for diff in diff_ts]
    graph_ts = [g for ts in graph_ts for g in ts]
    num_nodes = graph_ts[0].number_of_nodes()
    # Indices of the strictly lower-triangular adjacency entries.
    ix = np.array([(i, j) for i in range(1, num_nodes) for j in range(i)])
    prev_labels = [nx.to_numpy_array(g)[ix[:, 0], ix[:, 1]] for g in graph_ts]
    print('Converting to torch_geometric format')
    data = process_map(from_networkx, graph_ts, max_workers=n_workers, chunksize=20)
    # One-hot node identities as node features.
    one_hot = torch.eye(num_nodes)
    print('Setting attributes.')
    for d in tqdm(data):
        d.x = one_hot
    data = list(zip(prev_labels, data))
    return list(zip(data, diffs))
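
# For reference, a minimal sketch of what `compute_adj_delta` (imported from
# utils.graph_generators) is assumed to compute for one series [G_0, ..., G_T]:
# the T consecutive adjacency differences A_{t+1} - A_t. This illustrates the
# assumed contract only, not the actual implementation:
#
#   def compute_adj_delta(ts, abs=False):
#       deltas = []
#       for g_prev, g_next in zip(ts[:-1], ts[1:]):
#           d = nx.to_numpy_array(g_next) - nx.to_numpy_array(g_prev)
#           deltas.append(np.abs(d) if abs else d)
#       return deltas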
# def format_data(graph_ts, n_workers, start_idx=0):
# '''
# Prepare the data. This involves computing the delta matrices for the
# network time series, inserting them into the TreeLib and then putting them into torch geometric
# format for computation.
# Args:
# graph_ts: The time series to pre-process.
# start_idx: If some time-series have already been processed, we want to start indexing
# from the index of the last one in the previous batch (as they are accessed
# via this index in the TreeLib underlying the bigg model). This is used for training and validation split.
# Returns: A pytorch dataloader that loads the previous network combined with the index in the TreeLib of the
# associated delta matrix we want to learn.
# '''
# print('Computing Deltas')
# delta_fn = partial(compute_adj_delta, abs=False)
# diffs = process_map(delta_fn, graph_ts, max_workers=n_workers)
# graph_ts = [ts[:-1] for ts in graph_ts]
# # Flatten (will go into treelib in this order).
# diffs_pos = [[(d == 1).astype(np.int) for d in diff_ts] for diff_ts in diffs]
# diffs_neg = [[(d == -1).astype(np.int) for d in diff_ts] for diff_ts in diffs]
# diffs_pos = [nx.Graph(diff) for diff_ts in diffs_pos for diff in diff_ts]
# diffs_neg = [nx.Graph(diff) for diff_ts in diffs_neg for diff in diff_ts]
# graph_ts = [g for ts in graph_ts for g in ts]
# num_nodes = graph_ts[0].number_of_nodes()
# print('Converting to pyg format')
# data = process_map(from_networkx, graph_ts, max_workers=n_workers, chunksize=20)
# # data = [from_networkx(g) for g in tqdm(graph_ts)] # convert to torch_geometric format w/ edgelists.
# one_hot = torch.eye(num_nodes)
# print('Setting attributes.')
# for i in tqdm(range(len(data))):
# data[i].x = one_hot
# data[i].pos_id = (2 * i) + start_idx
# data[i].neg_id = (2 * i + 1) + start_idx
# return (data, diffs_pos, diffs_neg), data[-1].neg_id + 1
def main():
    c_args = parse_arguments()
    save_dir = c_args.data_path
    mkdir(save_dir)
    graphs_path = os.path.join(save_dir, f'{c_args.dataset_name}.pkl')
    graphs = load_graph_ts(graphs_path)
    if c_args.randomize:
        print('Randomizing')
        random.shuffle(graphs)
    train_ix = int(len(graphs) * c_args.train_p)
    if train_ix == 0:  # debugging: reuse the full set for both splits
        train_graphs = test_graphs = graphs
    else:
        train_graphs = graphs[:train_ix]
        test_graphs = graphs[train_ix:]
    val_len = int(len(train_graphs) * c_args.val_p)
    # Use the first val_len series for validation; when val_len == 0, fall
    # back to the full training set, mirroring the debugging case above.
    val_graphs = train_graphs[:val_len] if val_len > 0 else train_graphs[val_len:]
    # Remove the validation series from the training set.
    train_graphs = train_graphs[val_len:]
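    # Worked example (hypothetical numbers): with 100 series, train_p=0.8 and
    # val_p=0.1, train_ix = 80 and val_len = 8, giving 72 training, 8
    # validation and 20 test series.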
    print('Number of Training TS: ', len(train_graphs))
    print('Number of Val TS: ', len(val_graphs))
    print('Number of Test TS: ', len(test_graphs))
    print('TS length (T): ', len(test_graphs[0]))
    print('Saving Graphs')
    save_graph_list(
        train_graphs, os.path.join(save_dir, f'{c_args.dataset_name}_train_graphs_raw.pkl'))
    save_graph_list(
        val_graphs, os.path.join(save_dir, f'{c_args.dataset_name}_val_graphs_raw.pkl'))
    save_graph_list(
        test_graphs, os.path.join(save_dir, f'{c_args.dataset_name}_test_graphs.pkl'))
    ## The test set needs no pre-processing, so free it before the heavy conversion.
    del test_graphs
    del graphs
    gc.collect()
    ## Put the training and validation sets into the required format and save.
    train_processed = format_data(train_graphs, c_args.num_workers)
    save_graph_list(
        train_processed, os.path.join(save_dir, f'{c_args.dataset_name}_train_graphs.pkl'))
    del train_processed
    del train_graphs
    gc.collect()
    val_processed = format_data(val_graphs, c_args.num_workers)
    save_graph_list(
        val_processed, os.path.join(save_dir, f'{c_args.dataset_name}_val_graphs.pkl'))


if __name__ == '__main__':
    main()
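
# Expected outputs for a dataset named NAME (derived from the save calls above):
#   NAME_train_graphs_raw.pkl, NAME_val_graphs_raw.pkl, NAME_test_graphs.pkl,
#   NAME_train_graphs.pkl, NAME_val_graphs.pkl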