forked from facebookresearch/hsd3
-
Notifications
You must be signed in to change notification settings - Fork 0
/
SEP.py
559 lines (473 loc) · 20.6 KB
/
SEP.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
import copy
import heapq
import math
from collections import deque

#import numba as nb
import networkx as nx
import numpy as np
from scipy.spatial.distance import cdist
from tqdm import tqdm
def calculate_adj_matrix(data, threshold=0.5, similarity='euclidean'):
    """Build a binary adjacency matrix by thresholding pairwise distances.

    Args:
        data: Observations passed to ``scipy.spatial.distance.cdist`` as both
            operands (one observation per row).
        threshold: Cut-off applied to the pairwise measure.
        similarity: ``'euclidean'`` links pairs whose Euclidean distance is
            below ``threshold``; ``'cosine'`` links pairs for which
            ``1 - cosine_distance`` is below ``threshold``.

    Returns:
        A square integer array of 0/1 entries with a zero diagonal.

    Raises:
        ValueError: If ``similarity`` is neither ``'euclidean'`` nor
            ``'cosine'``.
    """
    if similarity not in ('euclidean', 'cosine'):
        raise ValueError("Unknown similarity measure")

    pairwise = cdist(data, data, similarity)
    if similarity == 'euclidean':
        linked = pairwise < threshold
    else:
        # NOTE(review): this links pairs whose cosine *similarity* is below
        # the threshold (i.e. dissimilar pairs), the opposite sense of the
        # Euclidean branch -- confirm this is intended.
        linked = 1 - pairwise < threshold

    adjacency = np.where(linked, 1, 0)
    # No self-loops.
    np.fill_diagonal(adjacency, 0)
    return adjacency
def get_id():
    """Yield the integers 0, 1, 2, ... without end (a fresh-ID source)."""
    counter = -1
    while True:
        counter += 1
        yield counter
def graph_parse(adj_matrix):
    """Extract basic graph statistics from a square adjacency matrix.

    Args:
        adj_matrix: Object with a ``shape`` attribute that supports
            ``adj_matrix[i, j]`` indexing; only the first dimension is used
            as the node count.

    Returns:
        A tuple ``(num_nodes, total_volume, node_volumes, adjacency_table)``:
        ``total_volume`` is the sum of all non-zero entries, ``node_volumes``
        is a list holding the sum of each row's non-zero entries, and
        ``adjacency_table`` maps each row index to the set of column indices
        with a non-zero entry.
    """
    n = adj_matrix.shape[0]
    total = 0
    volumes = []
    neighbours = {}
    for row in range(n):
        row_vol = 0
        row_adj = set()
        for col in range(n):
            weight = adj_matrix[row, col]
            if weight == 0:
                continue
            row_vol += weight
            total += weight
            row_adj.add(col)
        neighbours[row] = row_adj
        volumes.append(row_vol)
    return n, total, volumes, neighbours
#@nb.jit(nopython=True)
def cut_volume(adj_matrix, p1, p2):
    """Sum the adjacency weights between two groups of vertices.

    Args:
        adj_matrix: Matrix indexed as ``adj_matrix[u, v]``.
        p1: Sequence of vertex indices used as row indices.
        p2: Sequence of vertex indices used as column indices.

    Returns:
        The sum of all non-zero ``adj_matrix[u, v]`` with ``u`` in ``p1`` and
        ``v`` in ``p2``; ``0`` when there is no such entry.
    """
    total = 0
    for u in p1:
        for v in p2:
            weight = adj_matrix[u, v]
            if weight != 0:
                total += weight
    return total
def LayerFirst(node_dict, start_id):
    """Yield node IDs in breadth-first (layer-by-layer) order from ``start_id``.

    Args:
        node_dict: Mapping from node ID to a node object exposing a
            ``children`` attribute (an iterable of child IDs, or a falsy
            value such as ``None`` for a node without children).
        start_id: ID of the node at which the traversal starts.

    Yields:
        Node IDs, every parent before its children.
    """
    # A deque pops from the front in O(1); list.pop(0) shifts every remaining
    # element, which made the traversal quadratic in the number of nodes.
    pending = deque([start_id])
    while pending:
        node_id = pending.popleft()
        yield node_id
        # Children are looked up only after the consumer resumes the
        # generator, so changes made to the node between yields are observed.
        children = node_dict[node_id].children
        if children:
            pending.extend(children)
def merge(new_ID, id1, id2, cut_v, node_dict):
    """Create a parent node over ``id1`` and ``id2`` and register it.

    The new node's partition is the concatenation of both partitions, its
    volume is the sum of both volumes, and its ``g`` is the sum of both ``g``
    values minus twice ``cut_v``.  Its ``child_h`` is one more than the
    larger ``child_h`` of the two children, and ``cut_v`` is stored as its
    ``child_cut``.

    Args:
        new_ID: ID under which the new node is stored in ``node_dict``.
        id1: ID of the first node to merge.
        id2: ID of the second node to merge.
        cut_v: Cut weight between the two nodes.
        node_dict: Mapping from node ID to node; updated in place (the new
            node is inserted and both children get ``parent = new_ID``).
    """
    left = node_dict[id1]
    right = node_dict[id2]
    parent = PartitionTreeNode(
        ID=new_ID,
        partition=left.partition + right.partition,
        children={id1, id2},
        g=left.g + right.g - 2 * cut_v,
        vol=left.vol + right.vol,
        child_h=max(left.child_h, right.child_h) + 1,
        child_cut=cut_v,
    )
    left.parent = right.parent = new_ID
    node_dict[new_ID] = parent
def compressNode(node_dict, node_id, parent_id):
    """Remove ``node_id`` from the tree and hand its children to its parent.

    Args:
        node_dict: Mapping from node ID to node; updated in place.
        node_id: ID of the node to remove.
        parent_id: ID of that node's parent.

    Side effects:
        * the parent's ``child_cut`` grows by the removed node's
          ``child_cut``;
        * the removed node's children become children of the parent and get
          ``parent = parent_id``;
        * ``node_id`` is deleted from ``node_dict``;
        * if the removed node was exactly one level below the parent's
          recorded height, ``child_h`` is recomputed for the parent and then
          for successive ancestors until one is already consistent.
    """
    parent = node_dict[parent_id]
    victim = node_dict[node_id]
    height_before = parent.child_h
    grandchildren = victim.children

    parent.child_cut += victim.child_cut
    # Drop the node from the existing set in place, then rebind ``children``
    # to a new set that also holds the adopted grandchildren.
    parent.children.remove(node_id)
    parent.children = parent.children.union(grandchildren)
    for child_id in grandchildren:
        node_dict[child_id].parent = parent_id

    victim_height = victim.child_h
    del node_dict[node_id]

    if height_before - victim_height != 1:
        # The removed node did not determine the parent's height.
        return

    current = parent_id
    while current is not None:
        node = node_dict[current]
        tallest = max(node_dict[c].child_h for c in node.children)
        if node.child_h == tallest + 1:
            break
        node.child_h = tallest + 1
        current = node.parent
def child_tree_deepth(node_dict, nid):
    """Return the number of ancestors of ``nid`` plus its own ``child_h``.

    Args:
        node_dict: Mapping from node ID to node (``parent`` / ``child_h``).
        nid: ID of the node to measure.

    Returns:
        The count of ``parent`` links followed until a node with
        ``parent is None`` is reached, plus ``node_dict[nid].child_h``.
    """
    start = node_dict[nid]
    ancestors = 0
    cursor = start
    while cursor.parent is not None:
        cursor = node_dict[cursor.parent]
        ancestors += 1
    return ancestors + start.child_h
def CompressDelta(node1, p_node):
    """Score the removal of ``node1`` from under ``p_node``.

    Returns ``node1.child_cut * ln(p_node.vol / node1.vol)`` (natural
    logarithm).  The value is used as the priority of compression candidates.
    """
    volume_ratio = p_node.vol / node1.vol
    return node1.child_cut * math.log(volume_ratio)
def CombineDelta(node1, node2, cut_v, g_vol):
    """Score the merge of ``node1`` and ``node2`` into one parent.

    With ``v12 = node1.vol + node2.vol`` the result is::

        ((v1 - g1) * log2(v12 / v1)
         + (v2 - g2) * log2(v12 / v2)
         - 2 * cut_v * log2(g_vol / v12)) / g_vol

    Args:
        node1: First node (``vol`` and ``g`` are read).
        node2: Second node (``vol`` and ``g`` are read).
        cut_v: Cut weight between the two nodes.
        g_vol: Volume of the graph the merge takes place in.

    Returns:
        The score as a float; merge candidates are ordered by it.
    """
    vol_a, vol_b = node1.vol, node2.vol
    merged_vol = vol_a + vol_b
    gain_a = (vol_a - node1.g) * math.log(merged_vol / vol_a, 2)
    gain_b = (vol_b - node2.g) * math.log(merged_vol / vol_b, 2)
    penalty = 2 * cut_v * math.log(g_vol / merged_vol, 2)
    return (gain_a + gain_b - penalty) / g_vol
class PartitionTreeNode():
    """A single node of a partition tree.

    Attributes are plain data; the tree structure is expressed through the
    ``parent`` ID and the ``children`` set of IDs.
    """

    def __init__(self, ID, partition, vol, g, children: set = None, parent=None, child_h=0, child_cut=0):
        """Store the node's fields.

        Args:
            ID: Identifier of this node.
            partition: List of vertices covered by this node.
            vol: Volume of the node.
            g: Cut value of the node.
            children: Set of child IDs, or ``None``.
            parent: Parent ID, or ``None``.
            child_h: Height of the subtree below this node, not counting the
                node itself.
            child_cut: Cut weight recorded when the node's children were
                merged.
        """
        # The assignment order fixes the order of ``__dict__``, which
        # ``gatherAttrs`` (and therefore ``str()``) reports.
        self.ID = ID
        self.partition = partition
        self.parent = parent
        self.children = children
        self.vol = vol
        self.g = g
        self.merged = False
        self.child_h = child_h  # subtree height, excluding this node
        self.child_cut = child_cut

    def __str__(self):
        """Return ``{ClassName:attr=value,...}``."""
        return f"{{{self.__class__.__name__}:{self.gatherAttrs()}}}"

    def gatherAttrs(self):
        """Return every instance attribute as comma-joined ``name=value``."""
        return ",".join(f"{key}={value}" for key, value in vars(self).items())
class PartitionTree():
    """Hierarchical partition ("coding") tree built over an adjacency matrix.

    Nodes live in ``self.tree_node`` (ID -> ``PartitionTreeNode``); the
    initial leaves, one per vertex, are listed in ``self.leaves``.
    ``self.root_id`` only exists after ``build_coding_tree`` has built a tree.
    ``self.adj_table`` (ID -> set of neighbouring IDs) is mutated throughout
    construction and also receives entries for internal nodes.
    """

    def __init__(self,adj_matrix):
        """Parse ``adj_matrix`` with ``graph_parse`` and create the leaves.

        Args:
            adj_matrix: Square adjacency matrix supporting ``shape`` and
                ``[i, j]`` indexing.
                NOTE(review): the merge logic (``j > i`` pair filter,
                ``g1 + g2 - 2 * cut``) appears to assume a symmetric
                matrix -- confirm with callers.
        """
        self.adj_matrix = adj_matrix
        self.tree_node = {}
        self.g_num_nodes, self.VOL, self.node_vol, self.adj_table = graph_parse(adj_matrix)
        # Single source of fresh node IDs for leaves and internal nodes.
        self.id_g = get_id()
        self.leaves = []
        self.build_leaves()

    def build_leaves(self):
        """Create one leaf per vertex, with ``g`` and ``vol`` both set to the
        vertex volume, and record its ID in ``self.leaves``."""
        for vertex in range(self.g_num_nodes):
            ID = next(self.id_g)
            v = self.node_vol[vertex]
            leaf_node = PartitionTreeNode(ID=ID, partition=[vertex], g = v, vol=v)
            self.tree_node[ID] = leaf_node
            self.leaves.append(ID)

    def build_sub_leaves(self,node_list,p_vol):
        """Build fresh leaf nodes for the subgraph induced by ``node_list``.

        Args:
            node_list: Vertex IDs of the subgraph; they are used both as keys
                of ``self.tree_node`` and as indices into the adjacency
                matrix.
            p_vol: Volume of the node that currently owns these vertices.

        Returns:
            ``(subgraph_node_dict, ori_ent)``: the new leaves keyed by vertex
            ID (``g`` and ``vol`` both equal the vertex's weight inside the
            subgraph), and the sum of ``-(g / VOL) * log2(vol / p_vol)`` over
            the existing tree nodes of those vertices.

        Side effects:
            ``self.adj_table[vertex]`` is overwritten with the neighbours
            restricted to ``node_list``.
        """
        subgraph_node_dict = {}
        ori_ent = 0
        for vertex in node_list:
            ori_ent += -(self.tree_node[vertex].g / self.VOL)\
                       * math.log2(self.tree_node[vertex].vol / p_vol)
            sub_n = set()
            vol = 0
            for vertex_n in node_list:
                c = self.adj_matrix[vertex,vertex_n]
                if c != 0:
                    vol += c
                    sub_n.add(vertex_n)
            sub_leaf = PartitionTreeNode(ID=vertex,partition=[vertex],g=vol,vol=vol)
            subgraph_node_dict[vertex] = sub_leaf
            self.adj_table[vertex] = sub_n
        return subgraph_node_dict,ori_ent

    def build_root_down(self):
        """Prepare the root's children as the leaves of a new sub-problem.

        Returns:
            ``(subgraph_node_dict, ori_en)``: a fresh node per root child
            (same ID, ``vol`` and ``g``; the ``partition`` and ``children``
            objects are shared with the original node, not copied), and the
            sum of ``-(g / root_vol) * log2(vol / root_vol)`` over those
            children.

        Side effects:
            ``self.adj_table`` of every root child is restricted to IDs that
            are themselves root children.
        """
        root_child = self.tree_node[self.root_id].children
        subgraph_node_dict = {}
        ori_en = 0
        g_vol = self.tree_node[self.root_id].vol
        for node_id in root_child:
            node = self.tree_node[node_id]
            ori_en += -(node.g / g_vol) * math.log2(node.vol / g_vol)
            new_n = set()
            for nei in self.adj_table[node_id]:
                if nei in root_child:
                    new_n.add(nei)
            self.adj_table[node_id] = new_n
            new_node = PartitionTreeNode(ID=node_id,partition=node.partition,vol=node.vol,g = node.g,children=node.children)
            subgraph_node_dict[node_id] = new_node
        return subgraph_node_dict, ori_en

    def entropy(self,node_dict = None):
        """Return the sum of ``-(g / VOL) * log2(vol / parent_vol)`` over all
        nodes of ``node_dict`` that have a parent.

        Args:
            node_dict: Nodes to evaluate; defaults to ``self.tree_node``.
        """
        if node_dict is None:
            node_dict = self.tree_node
        ent = 0
        for node_id,node in node_dict.items():
            if node.parent is not None:
                node_p = node_dict[node.parent]
                node_vol = node.vol
                node_g = node.g
                node_p_vol = node_p.vol
                ent += - (node_g / self.VOL) * math.log2(node_vol / node_p_vol)
        return ent

    def __build_k_tree(self,g_vol,nodes_dict:dict,k = None,):
        """Greedily merge the nodes of ``nodes_dict`` into one tree and, if
        ``k`` is given, compress that tree until the root's ``child_h`` is at
        most ``k``.

        Args:
            g_vol: Graph volume used in the merge scores and as the volume of
                a synthetic root.
            nodes_dict: ID -> node mapping to build on; new internal nodes
                are inserted into it and compressed nodes removed from it.
            k: Maximum allowed ``child_h`` of the root, or ``None`` to skip
                the compression phase.

        Returns:
            The ID of the root of the resulting tree.
        """
        min_heap = []   # merge candidates: (CombineDelta, id_a, id_b, cut)
        cmp_heap = []   # compression candidates: [CompressDelta, id, parent_id]
        nodes_ids = nodes_dict.keys()   # live view: also sees nodes added below
        new_id = None
        # Seed the merge heap with every adjacent pair, each pair once (j > i).
        for i in nodes_ids:
            for j in self.adj_table[i]:
                if j > i:
                    n1 = nodes_dict[i]
                    n2 = nodes_dict[j]
                    if len(n1.partition) == 1 and len(n2.partition) == 1:
                        # Two single vertices: the cut is one matrix entry.
                        cut_v = self.adj_matrix[n1.partition[0],n2.partition[0]]
                    else:
                        cut_v = cut_volume(self.adj_matrix,p1 = np.array(n1.partition),p2=np.array(n2.partition))
                    diff = CombineDelta(nodes_dict[i], nodes_dict[j], cut_v, g_vol)
                    heapq.heappush(min_heap, (diff, i, j, cut_v))
        unmerged_count = len(nodes_ids)
        #pbar = tqdm(total=unmerged_count, desc="Building k-tree")
        # Phase 1: repeatedly merge the candidate pair with the smallest score.
        while unmerged_count > 1:
            if len(min_heap) == 0:
                break
            diff, id1, id2, cut_v = heapq.heappop(min_heap)
            # Stale entry: one of the two nodes was already merged elsewhere.
            if nodes_dict[id1].merged or nodes_dict[id2].merged:
                continue
            nodes_dict[id1].merged = True
            nodes_dict[id2].merged = True
            new_id = next(self.id_g)
            merge(new_id, id1, id2, cut_v, nodes_dict)
            # The new node inherits the neighbours of both children, and each
            # of those neighbours learns about the new node.
            self.adj_table[new_id] = self.adj_table[id1].union(self.adj_table[id2])
            for i in self.adj_table[new_id]:
                self.adj_table[i].add(new_id)
            # Children that are internal nodes become compression candidates.
            if nodes_dict[id1].child_h > 0:
                heapq.heappush(cmp_heap,[CompressDelta(nodes_dict[id1],nodes_dict[new_id]),id1,new_id])
            if nodes_dict[id2].child_h > 0:
                heapq.heappush(cmp_heap,[CompressDelta(nodes_dict[id2],nodes_dict[new_id]),id2,new_id])
            unmerged_count -= 1
            # Queue merges between the new node and its still-unmerged
            # neighbours.
            for ID in self.adj_table[new_id]:
                if not nodes_dict[ID].merged:
                    n1 = nodes_dict[ID]
                    n2 = nodes_dict[new_id]
                    cut_v = cut_volume(self.adj_matrix,np.array(n1.partition), np.array(n2.partition))
                    new_diff = CombineDelta(nodes_dict[ID], nodes_dict[new_id], cut_v, g_vol)
                    heapq.heappush(min_heap, (new_diff, ID, new_id, cut_v))
        root = new_id
        if unmerged_count > 1:
            # Phase 2: the heap ran dry with several unmerged nodes left
            # (no remaining adjacent pairs); hang all of them under one
            # synthetic root with vol = g_vol and g = 0.
            assert len(min_heap) == 0
            unmerged_nodes = {i for i, j in nodes_dict.items() if not j.merged}
            new_child_h = max([nodes_dict[i].child_h for i in unmerged_nodes]) + 1
            new_id = next(self.id_g)
            new_node = PartitionTreeNode(ID=new_id,partition=list(nodes_ids),children=unmerged_nodes,
                                         vol=g_vol,g = 0,child_h=new_child_h)
            nodes_dict[new_id] = new_node
            for i in unmerged_nodes:
                nodes_dict[i].merged = True
                nodes_dict[i].parent = new_id
                if nodes_dict[i].child_h > 0:
                    heapq.heappush(cmp_heap, [CompressDelta(nodes_dict[i], nodes_dict[new_id]), i, new_id])
            root = new_id
        if k is not None:
            # Phase 3: remove internal nodes, smallest CompressDelta first,
            # until the root's subtree height is at most k.
            while nodes_dict[root].child_h > k:
                diff, node_id, p_id = heapq.heappop(cmp_heap)
                # Removing this node would not shorten an over-deep branch.
                if child_tree_deepth(nodes_dict, node_id) <= k:
                    continue
                children = nodes_dict[node_id].children
                compressNode(nodes_dict, node_id, p_id)
                if nodes_dict[root].child_h == k:
                    break
                # Heap entries are lists so that their score/parent can be
                # patched in place; the heap order is restored by heapify.
                for e in cmp_heap:
                    if e[1] == p_id:
                        # The parent's child_cut changed: refresh its score.
                        if child_tree_deepth(nodes_dict, p_id) > k:
                            e[0] = CompressDelta(nodes_dict[e[1]], nodes_dict[e[2]])
                    if e[1] in children:
                        # Former children of the removed node now hang under
                        # p_id; leaves are never compression candidates.
                        if nodes_dict[e[1]].child_h == 0:
                            continue
                        if child_tree_deepth(nodes_dict, e[1]) > k:
                            e[2] = p_id
                            e[0] = CompressDelta(nodes_dict[e[1]], nodes_dict[p_id])
                heapq.heapify(cmp_heap)
        return root

    def check_balance(self,node_dict,root_id):
        """Call ``single_up`` on every child of ``root_id`` whose ``child_h``
        is 0, so that no such child hangs directly under the root."""
        # Iterate over a copy: single_up mutates the root's children set.
        root_c = copy.deepcopy(node_dict[root_id].children)
        for c in root_c:
            if node_dict[c].child_h == 0:
                self.single_up(node_dict,c)

    def single_up(self,node_dict,node_id):
        """Insert a new single-child node between ``node_id`` and its parent.

        The new node takes ``node_id``'s partition, volume and ``g``; its
        ``child_h`` is ``node_id``'s ``child_h`` plus one.  In
        ``self.adj_table`` the new ID is bound to the same set object as
        ``node_id`` and is added to the set of each of its neighbours.
        """
        new_id = next(self.id_g)
        p_id = node_dict[node_id].parent
        grow_node = PartitionTreeNode(ID=new_id, partition=node_dict[node_id].partition, parent=p_id,
                                      children={node_id}, vol=node_dict[node_id].vol, g=node_dict[node_id].g)
        node_dict[node_id].parent = new_id
        node_dict[p_id].children.remove(node_id)
        node_dict[p_id].children.add(new_id)
        node_dict[new_id] = grow_node
        node_dict[new_id].child_h = node_dict[node_id].child_h + 1
        self.adj_table[new_id] = self.adj_table[node_id]
        # NOTE(review): if node_id is listed among its own neighbours
        # (self-loop), this loop adds to the set it is iterating over.
        for i in self.adj_table[node_id]:
            self.adj_table[i].add(new_id)

    def root_down_delta(self):
        """Evaluate inserting a new level directly below the root.

        Returns:
            ``(delta, new_root, subgraph_node_dict)`` where ``delta`` is the
            entropy difference (before minus after) divided by the number of
            root children; ``(0, None, None)`` when the root has fewer than
            three children.
        """
        if len(self.tree_node[self.root_id].children) < 3:
            return 0 , None , None
        subgraph_node_dict, ori_entropy = self.build_root_down()
        g_vol = self.tree_node[self.root_id].vol
        new_root = self.__build_k_tree(g_vol=g_vol,nodes_dict=subgraph_node_dict,k=2)
        self.check_balance(subgraph_node_dict,new_root)
        new_entropy = self.entropy(subgraph_node_dict)
        delta = (ori_entropy - new_entropy) / len(self.tree_node[self.root_id].children)
        return delta, new_root, subgraph_node_dict

    def leaf_up_entropy(self,sub_node_dict,sub_root_id,node_id):
        """Re-express a freshly built subtree in whole-graph volumes and
        return its entropy contribution.

        Walking the subtree breadth-first from ``sub_root_id``: the sub-root
        takes ``vol``/``g`` from ``self.tree_node[node_id]``; nodes with
        ``child_h == 1`` get ``vol`` recomputed as the sum of their vertices'
        volumes in ``self.tree_node`` (with ``g`` adjusted by the same inner
        volume); all remaining nodes copy ``vol``/``g`` from the
        ``self.tree_node`` entry with the same ID.

        Returns:
            The sum of ``-(g / VOL) * log2(vol / parent_vol)`` over every
            visited node except the sub-root.
        """
        ent = 0
        for sub_node_id in LayerFirst(sub_node_dict,sub_root_id):
            if sub_node_id == sub_root_id:
                sub_node_dict[sub_root_id].vol = self.tree_node[node_id].vol
                sub_node_dict[sub_root_id].g = self.tree_node[node_id].g
            elif sub_node_dict[sub_node_id].child_h == 1:
                node = sub_node_dict[sub_node_id]
                inner_vol = node.vol - node.g
                partition = node.partition
                ori_vol = sum(self.tree_node[i].vol for i in partition)
                ori_g = ori_vol - inner_vol
                node.vol = ori_vol
                node.g = ori_g
                node_p = sub_node_dict[node.parent]
                ent += -(node.g / self.VOL) * math.log2(node.vol / node_p.vol)
            else:
                node = sub_node_dict[sub_node_id]
                node.g = self.tree_node[sub_node_id].g
                node.vol = self.tree_node[sub_node_id].vol
                node_p = sub_node_dict[node.parent]
                ent += -(node.g / self.VOL) * math.log2(node.vol / node_p.vol)
        return ent

    def leaf_up(self):
        """Evaluate inserting a new level directly above the leaves.

        For every parent of a leaf whose partition has at least three
        vertices, a 2-level subtree is built over those vertices.

        Returns:
            ``(delta, id_mapping, h1_new_child_tree)``: the accumulated
            entropy difference divided by the number of graph nodes; a
            mapping from each leaf parent to the root of its new subtree
            (``None`` when its partition has one or two vertices); and a
            mapping from leaf parent to the node dict of its new subtree.
        """
        h1_id = set()
        h1_new_child_tree = {}
        id_mapping = {}
        for l in self.leaves:
            p = self.tree_node[l].parent
            h1_id.add(p)
        delta = 0
        for node_id in h1_id:
            candidate_node = self.tree_node[node_id]
            sub_nodes = candidate_node.partition
            # Partitions with one or two vertices are not re-partitioned.
            if len(sub_nodes) == 1:
                id_mapping[node_id] = None
            if len(sub_nodes) == 2:
                id_mapping[node_id] = None
            if len(sub_nodes) >= 3:
                sub_g_vol = candidate_node.vol - candidate_node.g
                subgraph_node_dict,ori_ent = self.build_sub_leaves(sub_nodes,candidate_node.vol)
                sub_root = self.__build_k_tree(g_vol=sub_g_vol,nodes_dict=subgraph_node_dict,k = 2)
                self.check_balance(subgraph_node_dict,sub_root)
                new_ent = self.leaf_up_entropy(subgraph_node_dict,sub_root,node_id)
                delta += (ori_ent - new_ent)
                h1_new_child_tree[node_id] = subgraph_node_dict
                id_mapping[node_id] = sub_root
        delta = delta / self.g_num_nodes
        return delta,id_mapping,h1_new_child_tree

    def leaf_up_update(self,id_mapping,leaf_up_dict):
        """Apply the result of ``leaf_up`` to ``self.tree_node``.

        Args:
            id_mapping: Leaf parent -> new subtree root (or ``None``), as
                returned by ``leaf_up``.
            leaf_up_dict: Leaf parent -> node dict of its new subtree.

        Leaf parents mapped to ``None`` get a single-child node inserted
        above each of their children; the others adopt the children of their
        new subtree root, whose nodes are merged into ``self.tree_node``.
        The root's ``child_h`` is then increased by one.
        """
        for node_id,h1_root in id_mapping.items():
            if h1_root is None:
                # Copy first: single_up mutates the children set.
                children = copy.deepcopy(self.tree_node[node_id].children)
                for i in children:
                    self.single_up(self.tree_node,i)
            else:
                h1_dict = leaf_up_dict[node_id]
                self.tree_node[node_id].children = h1_dict[h1_root].children
                for h1_c in h1_dict[h1_root].children:
                    assert h1_c not in self.tree_node
                    h1_dict[h1_c].parent = node_id
                # The subtree root is replaced by the existing node.
                h1_dict.pop(h1_root)
                self.tree_node.update(h1_dict)
        self.tree_node[self.root_id].child_h += 1

    def root_down_update(self, new_id , root_down_dict):
        """Apply the result of ``root_down_delta`` to ``self.tree_node``.

        The root adopts the children of the sub-problem root ``new_id``; the
        remaining nodes of ``root_down_dict`` are merged into
        ``self.tree_node`` and the root's ``child_h`` is increased by one.
        """
        self.tree_node[self.root_id].children = root_down_dict[new_id].children
        for node_id in root_down_dict[new_id].children:
            assert node_id not in self.tree_node
            root_down_dict[node_id].parent = self.root_id
        root_down_dict.pop(new_id)
        self.tree_node.update(root_down_dict)
        self.tree_node[self.root_id].child_h += 1

    def build_coding_tree(self, k=2, mode='v2'):
        """Build the tree in ``self.tree_node`` and set ``self.root_id``.

        Args:
            k: Target height.  ``k == 1`` returns immediately without
                building anything.
            mode: ``'v1'`` (or ``k is None``) builds the tree with a single
                merge-and-compress pass limited to height ``k``.  ``'v2'``
                first builds a height-2 tree and then adds one level per
                iteration, choosing between a leaf-up and a root-down
                expansion, until the root's ``child_h`` reaches ``k``.
                NOTE(review): any other ``mode`` leaves ``self.root_id``
                unassigned, so the final traversal fails -- confirm whether
                that should be a ``ValueError``.

        Raises:
            AssertionError: If the nodes reachable from the root do not
                account for every entry of ``self.tree_node``.
        """
        if k == 1:
            return
        if mode == 'v1' or k is None:
            self.root_id = self.__build_k_tree(self.VOL, self.tree_node, k = k)
        elif mode == 'v2':
            self.root_id = self.__build_k_tree(self.VOL, self.tree_node, k = 2)
            self.check_balance(self.tree_node,self.root_id)
            if self.tree_node[self.root_id].child_h < 2:
                self.tree_node[self.root_id].child_h = 2
            # flag: 0 = evaluate both expansions, 1 = only leaf-up needs
            # re-evaluating (it was just applied), 2 = only root-down does.
            flag = 0
            while self.tree_node[self.root_id].child_h < k:
                if flag == 0:
                    leaf_up_delta,id_mapping,leaf_up_dict = self.leaf_up()
                    root_down_delta, new_id , root_down_dict = self.root_down_delta()
                elif flag == 1:
                    leaf_up_delta, id_mapping, leaf_up_dict = self.leaf_up()
                elif flag == 2:
                    root_down_delta, new_id , root_down_dict = self.root_down_delta()
                else:
                    raise ValueError
                if leaf_up_delta < root_down_delta:
                    # Root-down wins: apply it; its delta is recomputed on
                    # the next iteration.
                    flag = 2
                    self.root_down_update(new_id,root_down_dict)
                else:
                    # Leaf-up wins: apply it.
                    flag = 1
                    self.leaf_up_update(id_mapping,leaf_up_dict)
                    # The cached root-down candidate is kept for the next
                    # iteration; refresh the children of its bottom-level
                    # nodes from the just-updated tree.
                    if root_down_delta != 0:
                        for root_down_id, root_down_node in root_down_dict.items():
                            if root_down_node.child_h == 0:
                                root_down_node.children = self.tree_node[root_down_id].children
        # Sanity check: every stored node must be reachable from the root.
        count = 0
        for _ in LayerFirst(self.tree_node, self.root_id):
            count += 1
        assert len(self.tree_node) == count
def load_graph(dname):
    """Load a graph dataset from ``datasets/<dname>/<name>.txt``.

    ``<name>`` is ``dname`` with every ``-`` removed.  The layout read here:

    * first line: the number of graphs;
    * per graph, a header line ``<num_nodes> <label>``;
    * per node, a line ``<tag> <degree> <nbr_1> ... <nbr_degree>``.  Any
      columns after the declared neighbours are ignored.

    Args:
        dname: Dataset directory name.

    Returns:
        A list with one ``{'G': networkx.Graph, 'label': int}`` dict per
        graph.  Nodes are numbered ``0 .. num_nodes - 1`` and carry the raw
        tag string as their ``tag`` attribute; ``label`` is the raw label
        from the file.

    Raises:
        OSError: If the dataset file cannot be opened.
        ValueError: If a count, label, tag, or neighbour is not an integer.
        AssertionError: If a graph ends up with a node count different from
            its header (e.g. an edge names a node outside the declared
            range).
    """
    # The original also collected node tags, float node attributes, and an
    # edge count, but none of them reached the returned value; that
    # bookkeeping (including a branch that could never execute for
    # attribute rows) has been removed.
    print('loading data')
    g_list = []
    labels_seen = set()
    with open('datasets/%s/%s.txt' % (dname, dname.replace('-', '')), 'r') as f:
        n_g = int(f.readline().strip())
        for _ in range(n_g):
            n, l = [int(w) for w in f.readline().strip().split()]
            labels_seen.add(l)
            g = nx.Graph()
            for j in range(n):
                row = f.readline().strip().split()
                # Tag + degree + that many neighbour IDs.
                n_fields = int(row[1]) + 2
                g.add_node(j, tag=row[0])
                fields = [int(w) for w in row[:n_fields]]
                for neighbour in fields[2:]:
                    g.add_edge(j, neighbour)
            assert len(g) == n
            g_list.append({'G': g, 'label': l})
    print("# data: %d\tlabel:%s" % (len(g_list), len(labels_seen)))
    return g_list
# Example 1: weighted, symmetric adjacency matrix of a 5-node graph.
# NOTE: this value is never used -- the name is rebound by the next statement.
undirected_adj = [[0, 3, 5, 8, 0], [3, 0, 6, 4, 11],
                  [5, 6, 0, 2, 0], [8, 4, 2, 0, 10],
                  [0, 11, 0, 10, 0]]
# Example 2: unweighted, symmetric adjacency matrix of a 7-node graph.
# This is the value left in ``undirected_adj`` at import time; the
# ``__main__`` block rebinds the name again with a random matrix.
undirected_adj = [[0, 1, 1, 0, 0, 0, 0],
                  [1, 0, 1, 0, 0, 0, 0],
                  [1, 1, 0, 1, 0, 0, 0],
                  [0, 0, 1, 0, 1, 0, 0],
                  [0, 0, 0, 1, 0, 1, 1],
                  [0, 0, 0, 0, 1, 0, 1],
                  [0, 0, 0, 0, 1, 1, 0]]
import numpy as np
import time
if __name__ == "__main__":
    # Demo: build a height-2 coding tree for a random undirected graph, dump
    # every tree node, and report how long the construction took.
    num_nodes = 128
    # Keep only the strict upper triangle and mirror it, so the matrix is
    # symmetric with a zero diagonal.  The tree construction counts each
    # node pair once (j > i) and subtracts twice the cut, which only makes
    # sense for a symmetric matrix; calculate_adj_matrix likewise zeroes the
    # diagonal.  A raw randint matrix satisfies neither.
    upper = np.triu(np.random.randint(0, 2, (num_nodes, num_nodes)), k=1)
    undirected_adj = upper + upper.T
    y = PartitionTree(adj_matrix=undirected_adj)
    # perf_counter is monotonic and high-resolution, unlike wall-clock time.
    start_time = time.perf_counter()
    y.build_coding_tree(2)
    end_time = time.perf_counter()
    for k, v in y.tree_node.items():
        print(k, v.__dict__)
    print(f"Time taken: {end_time - start_time} seconds")