forked from trneedham/Decorated-Merge-Trees
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDMT_tools.py
2425 lines (1747 loc) · 76.2 KB
/
DMT_tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import numpy as np
import matplotlib.pyplot as plt
from ripser import ripser
import persim
import networkx as nx
import random
import copy
from scipy.cluster.hierarchy import dendrogram, linkage, cut_tree
from scipy.optimize import linear_sum_assignment
from sklearn.metrics.pairwise import pairwise_distances
import gudhi as gd
import ot
from bisect import bisect_left
from hopcroftkarp import HopcroftKarp
"""
A collection of functions for manipulating, visualizing and comparing merge trees
and merge trees decorated with persistence data.
Code associated to the paper "Decorated Merge Trees for Persistent Topology", by
Justin Curry, Haibin Hang, Washington Mio, Tom Needham and Osman Okutan
Code by Tom Needham
"""
"""
Merge Tree class
"""
"""
Merge trees are represented as classes with basic attributes and visualization methods.
A merge tree class can be fit from:
- directly inputting the tree/node height
- a point cloud, assumed to be Euclidean, created using Vietoris-Rips
- a network, with node filtration specified
"""
"""
Helper Functions
"""
def invert_label_dict(label):
    """Invert a label dictionary.

    Takes a dict ``{key: item}`` and returns ``{item: key}``. When an item is
    a list, each element of that list is mapped back to the key.
    """
    inverted_label = {}
    for key, item in label.items():
        # Only exact lists are unpacked; any other item is used as a key as-is.
        targets = item if type(item) == list else [item]
        for target in targets:
            inverted_label[target] = key
    return inverted_label
def get_key(dictionary, val):
    """Return the list of keys of ``dictionary`` whose value equals ``val``.

    Keys appear in the dictionary's iteration order; the list is empty when
    no value matches.
    """
    return [key for key, value in dictionary.items() if val == value]
def matrix_ell_infinity_distance(M1, M2):
    """Return the ell-infinity distance between two numpy arrays.

    Inputs: two numpy arrays of the same size.
    Output: the largest absolute entrywise difference.
    """
    return np.max(np.abs(M1 - M2))
def remove_short_bars(dgm, thresh=0.01):
    """Drop the bars of a persistence diagram that are not longer than ``thresh``.

    ``dgm`` is indexed as a 2D array whose column 0 holds births and column 1
    holds deaths; the rows with ``death - birth > thresh`` are returned.
    """
    bar_lengths = dgm[:, 1] - dgm[:, 0]
    return dgm[bar_lengths > thresh]
def linkage_to_merge_tree(L, X):
    """Convert a linkage matrix into a merge tree.

    Inputs: L is a linkage matrix (as produced by scipy's ``linkage``), X is
    the array the linkage was computed from; only ``X.shape[0]`` is used, as
    the number of leaves.
    Output: (T, height) where T is an nx.Graph and height is a {node: height}
    dictionary. Leaves 0..num_leaves-1 get height 0; the node created by
    row j of L is labelled num_leaves + j and gets height L[j, 2].
    """
    num_leaves = X.shape[0]
    height = {leaf: 0 for leaf in range(num_leaves)}
    edge_list = []
    for j, row in enumerate(L):
        merged = num_leaves + j
        # Both clusters joined in this row become children of the new node.
        edge_list.append((int(row[0]), merged))
        edge_list.append((int(row[1]), merged))
        height[merged] = row[2]
    T = nx.Graph()
    T.add_nodes_from(np.unique(L[:, :2]).astype(int))
    T.add_edges_from(edge_list)
    return T, height
"""
Network to Merge Tree Tools
Extracting a merge tree from a network requires more work.
"""
def perturb_function(f, perturbation=1e-10):
    """Break ties in a dictionary-valued function by adding small random noise.

    If the values of ``f`` are not all distinct, every value is shifted by
    ``perturbation * np.random.rand()``; otherwise ``f`` is returned unchanged.
    """
    values = list(f.values())
    has_repeats = len(np.unique(values)) < len(values)
    if has_repeats:
        f = {node: value + perturbation * np.random.rand() for node, value in f.items()}
    return f
def get_merge_tree_from_network(G, f):
    """Build the sublevel-set merge tree of a node-filtered network.

    Inputs: G is an nx.Graph; f is a dictionary {node: filtration value}.
    Output: (T, merge_tree_heights) where T is an nx.Graph and
    merge_tree_heights is a {node: height} dictionary on the nodes of T.

    For each filtration value, in increasing order, the connected components
    of the subgraph induced by the nodes with value <= that level are
    computed. Each component is represented by its node of largest filtration
    value, and a representative from the previous level is joined to the
    representative of the component that contains it at the current level.
    """
    # The algorithm needs unique entries to work. Check here and perturb if
    # that's not the case.
    values = list(f.values())
    if len(np.unique(values)) < len(values):
        f = perturb_function(f, perturbation=1e-6)
    sorted_f = np.sort(np.unique(list(f.values())))
    T = nx.Graph()
    merge_tree_heights = {}
    # Representatives of the previous level; empty before the first level,
    # so no edges are added there.
    conn_comp_dict_prev = {}
    for level in sorted_f:
        subgraph_nodes = [j for j in f if f[j] <= level]
        H = G.subgraph(subgraph_nodes)
        conn_comp_dict = {}
        # Add vertices
        for component in nx.connected_components(H):
            c = list(component)
            # Representative = highest node of the component itself. With
            # unique filtration values this is the same node the value-based
            # lookup over the whole dictionary would return.
            top_node = max(c, key=f.__getitem__)
            T.add_node(top_node)
            conn_comp_dict[top_node] = c
            merge_tree_heights[top_node] = f[top_node]
        # Add edges from previous level
        for child_node in conn_comp_dict_prev:
            for parent_node, members in conn_comp_dict.items():
                if child_node != parent_node and child_node in members:
                    T.add_edge(child_node, parent_node)
        conn_comp_dict_prev = conn_comp_dict
    return T, merge_tree_heights
def get_barcodes_from_filtered_network(G, f, infinity=None):
    """Compute degree-0 and degree-1 barcodes of a node-filtered network.

    Inputs: G is an nx.Graph whose edges are inserted into a gudhi simplex
    tree; f is a dictionary {vertex: filtration value}; infinity optionally
    gives the value at which infinite degree-1 bars are truncated (defaults
    to the maximum value of f).
    Output: (spCpx, dgm0, dgm1), the simplex tree and the persistence
    intervals in dimensions 0 and 1, with infinite degree-1 bars truncated.
    """
    # Initialize with an empty simplex tree
    spCpx = gd.SimplexTree()
    # Add edges from the adjacency graph
    for edge in G.edges:
        spCpx.insert(list(edge))
    # Insert a single 2-dimensional simplex
    # NOTE(review): this triangle on vertices 0, 1, 2 is inserted regardless
    # of G; presumably it raises the complex to dimension 2 so that degree-1
    # persistence is reported -- confirm that this is the intent, since it
    # also fills in any cycle on those three vertices.
    spCpx.insert([0, 1, 2])
    # Add filtration values to vertices. Look the value up by the vertex id
    # carried by the simplex rather than by enumeration position, so vertices
    # missing from the complex (e.g. nodes without edges) cannot shift the
    # assignment. The skeleton is materialized before the tree is modified.
    zero_skeleton = list(spCpx.get_skeleton(0))
    for simplex, _ in zero_skeleton:
        spCpx.assign_filtration(simplex, filtration=f[simplex[0]])
    # Extend to filtration of the whole complex
    spCpx.make_filtration_non_decreasing()
    # Compute persistence (required before the intervals can be queried)
    spCpx.persistence()
    dgm0 = spCpx.persistence_intervals_in_dimension(0)
    dgm1 = spCpx.persistence_intervals_in_dimension(1)
    # Truncate infinite deg-1 bars to end at the maximum filtration value
    # OR the predefined value of infinity, which may be data-driven
    if infinity is None:
        max_val = np.max(list(f.values()))
    else:
        max_val = infinity
    dgm1_fixed = []
    for bar in dgm1:
        if bar[1] == np.inf:
            dgm1_fixed.append([bar[0], max_val])
        else:
            dgm1_fixed.append(bar)
    dgm1 = np.array(dgm1_fixed)
    return spCpx, dgm0, dgm1
def find_nearest(heights, value):
    """Return a node whose height is closest to ``value``.

    ``heights`` is a {node: height} dictionary. The closest height is found
    first; the first node (in dictionary order) carrying that height is
    returned.
    """
    candidates = np.array(list(heights.values()))
    best_val = candidates[np.abs(candidates - value).argmin()]
    matching_nodes = [node for node, h in heights.items() if best_val == h]
    return matching_nodes[0]
def decorate_merge_tree_networks(T, heights, dgm):
    """
    Inputs: merge tree (T,height), degree-1 persistence diagram as a list of lists of birth-death pairs
    Output: dictionary of barcodes associated to each leaf of the merge tree

    Each bar is attached in full to the leaves below the node whose height is
    nearest to its birth time. For every other leaf the bar is truncated at
    the height of the least common ancestor of that leaf and the node, and it
    is only recorded when the truncation result is a list.
    """
    root = get_key(heights, max(list(heights.values())))[0]
    leaves = [n for n in T.nodes() if T.degree(n) == 1 and n != root]
    leaf_barcodes = {leaf: [] for leaf in leaves}
    for bar in dgm:
        # Node at which the cycle is considered to be born
        cycle_ind = find_nearest(heights, bar[0])
        descendent_leaves = get_descendent_leaves(T, heights, cycle_ind)
        non_descendent_leaves = [leaf for leaf in leaves if leaf not in descendent_leaves]
        non_descendent_LCAs = {}
        for leaf in non_descendent_leaves:
            LCA_idx_tmp, _ = least_common_ancestor(T, heights, leaf, cycle_ind)
            non_descendent_LCAs[leaf] = LCA_idx_tmp[0]
        for leaf in descendent_leaves:
            leaf_barcodes[leaf] = leaf_barcodes[leaf] + [list(bar)]
        for leaf, ancestor in non_descendent_LCAs.items():
            truncated_bar = truncate_bar(bar, heights[ancestor])
            if type(truncated_bar) == list:
                leaf_barcodes[leaf] = leaf_barcodes[leaf] + [list(truncated_bar)]
    return leaf_barcodes
"""
Diffusion Frechet Functions - for filtering networks by density
"""
# Find eigenvalues and eigenvectors of the graph Laplacian
def laplacian_eig(G):
    """Eigen-decomposition of the Laplacian of a graph.

    Input: Networkx graph
    Output: eigenvalues and eigenvectors of the graph Laplacian, as returned
    by np.linalg.eigh applied to the dense Laplacian matrix
    """
    dense_laplacian = nx.laplacian_matrix(G).toarray()
    eigenvalues, eigenvectors = np.linalg.eigh(dense_laplacian)
    return eigenvalues, eigenvectors
# Create heat kernel matrix from precomputed eigenvalues/eigenvectors
def heat_kernel(lam, phi, t):
    """Heat kernel matrix phi * diag(exp(-t * lam)) * phi^T.

    Input: eigenvalues lam and eigenvector matrix phi, time parameter t
    Output: heat kernel matrix
    """
    decay = np.diag(np.exp(-t * lam))
    # Same association as phi * (decay * phi^T)
    return phi @ (decay @ phi.T)
def diffusion_distance_matrix(lam, phi, t):
    """Pairwise distances between columns of the heat kernel at time t.

    Input: eigenvalues lam, eigenvector matrix phi, time parameter t
    Output: symmetric matrix whose (i, j) entry is the Euclidean norm of the
    difference between columns i and j of the heat kernel matrix
    """
    n = len(lam)
    HK = heat_kernel(lam, phi, t)
    dist = np.zeros([n, n])
    for i in range(n):
        for j in range(i + 1, n):
            gap = np.linalg.norm(HK[:, i] - HK[:, j])
            # Fill both triangles directly; the diagonal stays zero.
            dist[i, j] = gap
            dist[j, i] = gap
    return dist
def diffusion_frechet_function(dist, mu):
    """Evaluate the Frechet function of a weight vector at every index.

    Input: square distance matrix dist, weight vector mu
    Output: dictionary {j: sum_k dist[j, k]**2 * mu[k]} over the rows of dist
    """
    return {j: np.dot(row ** 2, mu) for j, row in enumerate(dist)}
def get_diffusion_frechet_function(G, t, mu=None):
    """Diffusion Frechet function of a network at diffusion time t.

    Input: nx.Graph G, time parameter t, optional weight vector mu on the
    nodes (uniform weights from ot.unif when omitted)
    Output: dictionary {node index: Frechet function value}
    """
    # Reorder the node labels to be safe: rebuilding from the adjacency
    # matrix relabels the nodes by their position in that matrix.
    G = nx.from_numpy_array(nx.to_numpy_array(G))
    # Get spectral and distributional information
    lam, phi = laplacian_eig(G)
    dist = diffusion_distance_matrix(lam, phi, t)
    weights = ot.unif(len(G)) if mu is None else mu
    # Get diffusion frechet
    return diffusion_frechet_function(dist, weights)
"""
Distance to a Node Function
"""
def distance_to_a_node(G, node_id):
    """Filtration function given by (negated, shifted) distance to a node.

    Input: nx.Graph G and a node ``node_id`` of G.
    Output: dictionary {node: value} where value is minus the shortest-path
    distance from ``node_id`` to the node, shifted so that the minimum value
    is 0. The node farthest from ``node_id`` therefore gets value 0 and
    ``node_id`` itself gets the largest value.

    Raises ValueError if ``node_id`` is not a node of G.
    """
    nodes = list(G.nodes())
    # Rows and columns of the distance matrix follow ``nodes``, so index by
    # position rather than by node label (labels need not be 0..n-1).
    D = np.asarray(nx.floyd_warshall_numpy(G, nodelist=nodes))
    negated = -D[nodes.index(node_id)]
    shifted = negated - negated.min()
    return {n: shifted[i] for i, n in enumerate(nodes)}
"""
Manipulating Merge Trees
"""
def threshold_merge_tree(T,height,thresh):
    """
    Takes a merge tree and truncates it at the given threshold level.
    Makes a cut at threshold height, removes all lower vertices.

    Inputs: T, height are merge tree data (graph and {node:height} dictionary),
    thresh is the height at which the cut is made.
    Output: T_thresh, height_thresh, the truncated tree and its height dictionary.

    After the cut, each leaf of the truncated tree is joined by an edge to one
    representative leaf lying below it in the subdivided tree, and that
    representative is added back with its own height.
    """
    # Subdivide every edge crossing the threshold so the cut passes through nodes
    subdiv_heights = [thresh]
    T_sub, height_sub = subdivide_edges(T,height,subdiv_heights)
    # Distinct height values at or above the threshold
    height_array = np.array(list(set(height_sub.values())))
    height_array_thresh = height_array[height_array >= thresh]
    # Collect all nodes carrying one of the surviving heights
    kept_nodes = []
    for j in range(len(height_array_thresh)):
        kept_nodes += get_key(height_sub,height_array_thresh[j])
    T_thresh = T_sub.subgraph(kept_nodes).copy()
    height_thresh = {n:height_sub[n] for n in kept_nodes}
    # Root is taken to be the first node of maximal height
    root = get_key(height_thresh,max(list(height_thresh.values())))[0]
    T_thresh_leaves = [n for n in T_thresh.nodes() if T_thresh.degree(n) == 1 and n != root]
    for n in T_thresh_leaves:
        descendents = get_descendent_leaves(T_sub,height_sub,n)
        # Representative: a descendent leaf of minimal height. Heights are read
        # from the input dictionary `height` here, not from `height_sub`.
        # NOTE(review): if n is itself a leaf of T_sub, it is its own descendent
        # and this adds a self-loop -- confirm whether that case can occur.
        descendent_node_rep = list(set(get_key(height_sub,min([height[node] for node in descendents]))).intersection(set(descendents)))[0]
        T_thresh.add_edge(n,descendent_node_rep)
        height_thresh[descendent_node_rep] = height_sub[descendent_node_rep]
    return T_thresh, height_thresh
"""
Main class definition
"""
class MergeTree:
    """
    Creates a merge tree from (exactly) one of the three types of inputs:
    - T and height: the merge tree is defined directly from a nx.Graph() T (which must be
        a tree!) and a height function on the nodes (which must satisfy
        certain conditions). The height function should be a dictionary whose keys
        are node labels of T.
    - pointCloud: a Euclidean point cloud of shape (num points) x (dimension). The merge tree
        is generated from the Vietoris-Rips filtration of the point cloud
    - network and network filtration: the merge tree is created from connectivity information
        corresponding to the input network and filtration function on its nodes.
        network should be a nx.Graph() object.
        network_filtration has the following options:
        - a dictionary whose keys are node labels from the network and values are positive numbers
        - network_filtration = 'Diffusion', in which case the diffusion filtration function is
          used, with a variable t_diffusion parameter
        - network_filtration = 'Distance', in which case the distance to a specified node is used,
          with the node chosen by the user. The filtration is *superlevel set* for this function.
    Merge trees can be 'decorated' with higher-dimensional homological data by the 'fit_barcode' method.
    The result is a 'decorated merge tree'.

    Attributes set by the constructor:
    - T: the tree passed in by the user (None unless tree data was given)
    - tree, height: the merge tree and its height dictionary (simplified if simplify=True)
    - pointCloud, network, network_filtration, t_diffusion, node_distance: the inputs
    - filtration: the filtration dictionary (only set for network input)
    - leaf_barcode, ultramatrix, label, inverted_label: initialized to None
    """

    def __init__(self,
                 tree = None,
                 height = None,
                 pointCloud = None,
                 network = None,
                 network_filtration = None,
                 t_diffusion = 0.1,
                 node_distance = 0,
                 simplify = True):
        self.T = tree
        self.pointCloud = pointCloud
        # Store the network itself (this attribute is what the network
        # drawing methods read).
        self.network = network
        self.network_filtration = network_filtration
        self.t_diffusion = t_diffusion
        self.node_distance = node_distance
        self.leaf_barcode = None
        self.ultramatrix = None
        self.label = None
        self.inverted_label = None

        # The three kinds of input are mutually exclusive
        num_inputs = sum(data is not None for data in (tree, pointCloud, network))
        if num_inputs > 1:
            raise Exception('Only enter tree data OR point cloud data OR network data')

        # Define merge tree from tree/height data
        if tree is not None:
            if not nx.is_tree(tree):
                raise Exception('Input Graph must be a tree')
            if height is None or set(height.keys()) != set(tree.nodes):
                raise Exception('height keys must match node keys')
            self.tree = tree
            self.height = height
        # Define merge tree from point cloud data
        elif pointCloud is not None:
            L = linkage(pointCloud)
            T, height = linkage_to_merge_tree(L,pointCloud)
            self.tree = T
            self.height = height
        # Define merge tree from network data
        elif network is not None:
            if network_filtration is None:
                raise Exception('Network merge tree requires filtration---see documentation')
            elif isinstance(network_filtration, dict):
                if set(network_filtration.keys()) != set(network.nodes()):
                    raise Exception('Filtration keys must match node keys')
                f = network_filtration
            elif network_filtration == 'Diffusion':
                f = get_diffusion_frechet_function(network,t_diffusion)
            elif network_filtration == 'Distance':
                if node_distance not in list(network.nodes()):
                    raise Exception('Must enter a valid node to use Distance filtration')
                f = distance_to_a_node(network,node_distance)
            else:
                raise Exception('network_filtration is not valid')
            T, height = get_merge_tree_from_network(network,f)
            self.filtration = f
            self.tree = T
            self.height = height

        # Cleans up the merge tree by removing degree-2 nodes
        if simplify:
            TNew, heightNew = simplify_merge_tree(self.tree,self.height)
            self.tree = TNew
            self.height = heightNew

    # Creating a Decorated Merge Tree

    def fit_barcode(self,
                    degree = 1,
                    leaf_barcode = None):
        """
        Decorate the merge tree with barcodes attached to its leaves.

        If leaf_barcode is given it is stored directly. Otherwise, for a merge
        tree built from a point cloud, the diagram of the given degree is
        computed with ripser, distributed over the leaves, and bars of
        non-positive length are discarded. A directly defined merge tree
        requires leaf_barcode and raises an Exception without it.
        """
        if leaf_barcode is not None:
            self.leaf_barcode = leaf_barcode
            return
        if self.T is not None:
            raise Exception('fit_barcode for directly defined merge tree requires leaf_barcode input')
        if self.pointCloud is not None:
            dgm = ripser(self.pointCloud,maxdim = degree)['dgms'][-1]
            leaf_barcode_init = decorate_merge_tree(self.tree, self.height, self.pointCloud, dgm)
            leaf_barcode = {key: [bar for bar in bars if bar[1]-bar[0] > 0]
                            for key, bars in leaf_barcode_init.items()}
            self.barcode = dgm
            self.leaf_barcode = leaf_barcode

    # Getting the ultramatrix from a labeling of the merge tree

    def fit_ultramatrix(self,label = None):
        """
        Compute and store the ultramatrix for a labeling {node:label} of the tree.
        By default every node is labeled by its position in self.tree.nodes().
        """
        if label is None:
            label = {n:j for (j,n) in enumerate(self.tree.nodes())}
        ultramatrix, inverted_label = get_ultramatrix(self.tree,self.height,label)
        self.ultramatrix = ultramatrix
        self.label = label
        self.inverted_label = inverted_label

    # Merge tree manipulation

    def threshold(self,threshold):
        """
        Truncate the merge tree in place at the given threshold height. The
        leaf barcodes are updated as well when the tree is decorated, and any
        stored ultramatrix/labeling is cleared.
        """
        if self.leaf_barcode is None:
            T_thresh, height_thresh = threshold_merge_tree(self.tree,self.height,threshold)
        else:
            T_thresh, height_thresh, leaf_barcode_thresh = simplify_decorated_merge_tree(self.tree,self.height,self.leaf_barcode,threshold)
            self.leaf_barcode = leaf_barcode_thresh
        self.tree = T_thresh
        self.height = height_thresh
        # The stored labeling refers to the old tree
        self.ultramatrix = None
        self.label = None
        self.inverted_label = None

    def copy(self):
        """Return a shallow copy (copy.copy) of this merge tree object."""
        return copy.copy(self)

    # Visualization Tools

    # For general merge trees
    def draw(self, axes = False):
        """Draw the merge tree with nodes placed at their heights."""
        draw_merge_tree(self.tree,self.height,axes = axes)

    # For merge trees coming from a network
    def draw_network(self):
        """Draw the underlying network together with its filtration function."""
        draw_network_and_function(self.network,self.filtration)

    def draw_network_with_merge_tree(self):
        """Draw the underlying network together with its merge tree."""
        draw_network_and_merge_tree(self.network,self.filtration)

    def draw_with_labels(self,label):
        """Draw the merge tree with the given {node:label} labels shown."""
        draw_labeled_merge_tree(self.tree,self.height,label)

    def draw_decorated(self,tree_thresh,barcode_thresh):
        """
        Visualize the decorated merge tree of a point cloud. Does nothing for
        merge trees that were not built from a point cloud. Reads self.barcode,
        which is set by fit_barcode.
        """
        if self.pointCloud is not None:
            visualize_DMT_pointcloud(self.tree,
                                     self.height,
                                     self.barcode,
                                     self.pointCloud,
                                     tree_thresh,
                                     barcode_thresh)
"""
Plotting Functions
"""
def mergeTree_pos(G, height, root=None, width=1.0, xcenter = 0.5):
    '''
    Adapted from Joel's answer at https://stackoverflow.com/a/29597209/2966723.
    Licensed under Creative Commons Attribution-Share Alike

    Returns a dictionary {node: (x, y)} of positions for plotting a tree in a
    hierarchical layout, with every node at the vertical position given by
    its height.

    G: the graph (must be a tree, otherwise TypeError is raised)
    height: dictionary {node:height} of heights for the vertices of G.
            Must satisfy merge tree conditions, but this is not checked in this version of the function.
    root: not used; the root is always the first vertex of maximum height.
    width: horizontal space allocated for this branch - avoids overlap with other branches
    xcenter: horizontal location of root
    '''
    if not nx.is_tree(G):
        raise TypeError('cannot use hierarchy_pos on a graph that is not a tree')
    # The root for the tree is the vertex with maximum height value
    top_height = max(list(height.values()))
    root = get_key(height, top_height)[0]
    pos = {}

    def _place(node, y, span, x, parent):
        # Record this node, then share its horizontal span evenly among the
        # neighbors other than the one we arrived from.
        pos[node] = (x, y)
        children = list(G.neighbors(node))
        if not isinstance(G, nx.DiGraph) and parent is not None:
            children.remove(parent)
        if len(children) != 0:
            dx = span / len(children)
            nextx = x - span / 2 - dx / 2
            for child in children:
                nextx += dx
                _place(child, height[child], dx, nextx, node)

    _place(root, top_height, width, xcenter, None)
    return pos
def draw_merge_tree(G, height, axes=False):
    """Draw a merge tree in a new figure with nodes at their heights.

    Input: merge tree as G, height; axes toggles the left (height) axis ticks
    Output: none; the tree is drawn with node labels
    """
    layout = mergeTree_pos(G, height)
    _, ax = plt.subplots()
    nx.draw_networkx(G, pos=layout, with_labels=True)
    if not axes:
        return
    ax.tick_params(left=True, bottom=False, labelleft=True, labelbottom=False)
    return
def draw_labeled_merge_tree(T, height, label, axes = False):
    """Draw a merge tree with labels shown over the labeled nodes.

    Input: merge tree as T, height. Label dictionary label with labels for
    certain nodes, of the form {node: label}. axes toggles the left (height)
    axis ticks.
    Output: none; the tree is drawn on the current matplotlib axes.
    """
    pos = mergeTree_pos(T, height)
    draw_labels = {key: str(value) for key, value in label.items()}
    nx.draw_networkx(T, pos = pos, labels = draw_labels, node_color = 'r', font_weight = 'bold', font_size = 16)
    if axes:
        # No figure is created here, so style the axes that were just drawn on.
        ax = plt.gca()
        ax.tick_params(left=True, bottom=False, labelleft=True, labelbottom=False)
    return
"""
Merge Tree Processing
"""
def least_common_ancestor(G, height, vertex1, vertex2):
    """Find the least common ancestor of two vertices of a merge tree.

    Input: merge tree as G, height; two vertices of G
    Output: (LCA_idx, LCA_height). LCA_height is the smallest height among
    the vertices shared by the two paths to the root. LCA_idx is the list of
    all keys of ``height`` carrying that height value.
    """
    root = get_key(height, max(list(height.values())))[0]
    path_to_root1 = nx.shortest_path(G, source = vertex1, target = root)
    path_to_root2 = nx.shortest_path(G, source = vertex2, target = root)
    shared = set(path_to_root1).intersection(set(path_to_root2))
    LCA_height = min([height[n] for n in list(shared)])
    return get_key(height, LCA_height), LCA_height
def get_ultramatrix(T,height,label,return_inverted_label = True):
    """
    Gets an ultramatrix from a labeled merge tree.
    Input: T, height are data from a merge tree (tree structure and height function dictionary),
    label is a dictionary of node labels of the form {node:label}, where labels are given by a function
    {0,1,...,N} --> T, which is surjective onto the set of leaves of T.
    Output: matrix with (i,j) entry the height of the least common ancestor of nodes labeled i and j.
    The inverted label dictionary {label:node}, which is useful downstream, is always returned as well;
    the return_inverted_label argument is not consulted by this implementation.

    NOTE(review): the loops below look up label[i] for every node of every component, so this
    appears to need a label for each node of T, not only for the leaves -- confirm against callers.
    """
    inverted_label = invert_label_dict(label)
    ultramatrix = np.zeros([len(label),len(label)])
    # Diagonal: each labeled node is its own ancestor, at its own height
    for j in range(len(label)):
        ultramatrix[j,j] = height[inverted_label[j]]
    # Process nodes from the highest height down
    sorted_heights = np.sort(np.unique(list(height.values())))[::-1]
    # Start from the first node of maximal height (the root)
    old_node = get_key(height,sorted_heights[0])[0]
    for h in sorted_heights:
        node_list = get_key(height,h)
        for node in node_list:
            # Removing the node splits the tree; components that do not contain
            # the previously processed node are treated as lying below this node
            T_with_node_removed = T.copy()
            T_with_node_removed.remove_node(node)
            conn_comp_list = [list(c) for c in nx.connected_components(T_with_node_removed)]
            descendent_conn_comp_list = [c for c in conn_comp_list if old_node not in c]
            # The node meets everything below it at its own height
            for c in descendent_conn_comp_list:
                ultramatrix[label[node],[label[i] for i in c]] = h
            # Nodes in two different components below this node meet at height h
            for j in range(len(descendent_conn_comp_list)-1):
                c = descendent_conn_comp_list[j]
                for k in range(j+1,len(descendent_conn_comp_list)):
                    cc = descendent_conn_comp_list[k]
                    for i in c:
                        ultramatrix[label[i],[label[i] for i in cc]] = h
            # NOTE(review): the nesting level of this update was reconstructed from
            # a source without indentation -- confirm it belongs inside the node loop
            old_node = node
    # Only one triangle may have been filled for a pair; symmetrize with the max
    ultramatrix = np.maximum(ultramatrix,ultramatrix.T)
    return ultramatrix, inverted_label
"""
Matching merge trees and estimating interleaving distance
"""
def get_heights(height1, height2, mesh):
    """Evenly spaced heights covering the range of two height functions.

    Input: two {node: height} dictionaries and a mesh size
    Output: numpy array of floor((M - m) / mesh) + 1 equally spaced values
    from m to M, where m and M are the smallest and largest heights occurring
    in either dictionary
    """
    pooled = set(height1.values()) | set(height2.values())
    top = max(pooled)
    bottom = min(pooled)
    num_samples = int(np.floor((top - bottom) / mesh))
    return np.linspace(bottom, top, num_samples + 1)
def subdivide_edges_single_height(G, height, subdiv_height):
    """Insert a node on every edge that strictly crosses a given height.

    Input: merge tree as G, height; the height at which to subdivide
    Output: (G_sub, height_sub), copies of the inputs in which each edge whose
    endpoint heights lie strictly on either side of subdiv_height is replaced
    by two edges through a new node placed at subdiv_height. New nodes are
    numbered from max(G.nodes()) + 1 upwards.
    """
    G_sub = G.copy()
    height_sub = height.copy()
    next_idx = max(G.nodes()) + 1
    for u, v in G.edges():
        h_u, h_v = height[u], height[v]
        crosses = min(h_u, h_v) < subdiv_height < max(h_u, h_v)
        if not crosses:
            continue
        G_sub.add_node(next_idx)
        G_sub.add_edge(u, next_idx)
        G_sub.add_edge(next_idx, v)
        G_sub.remove_edge(u, v)
        height_sub[next_idx] = subdiv_height
        next_idx += 1
    return G_sub, height_sub
def subdivide_edges(G, height, subdiv_heights):
    """Subdivide the edges of a merge tree at each of the given heights in turn.

    Input: merge tree as G, height; iterable of heights
    Output: (G, height) after applying subdivide_edges_single_height once per
    entry of subdiv_heights, in order
    """
    current = (G, height)
    for level in subdiv_heights:
        current = subdivide_edges_single_height(current[0], current[1], level)
    return current[0], current[1]
def get_heights_and_subdivide_edges(G, height1, height2, mesh):
    """Subdivide a merge tree along a mesh shared with a second tree.

    Input: merge tree as G, height1; the height dictionary height2 of another
    merge tree; mesh size
    Output: (G, height1) subdivided at every height returned by
    get_heights(height1, height2, mesh)
    """
    mesh_heights = get_heights(height1, height2, mesh)
    G_sub, height_sub = subdivide_edges(G, height1, mesh_heights)
    return G_sub, height_sub
def interleaving_subdivided_trees(T1_sub, height1_sub, T2_sub, height2_sub, verbose = True):
    """Estimate the interleaving distance between two (subdivided) merge trees.

    Input: data from two merge trees (graphs and {node: height} dictionaries)
    Output: if verbose is False, the ell-infinity distance between the matched
    ultramatrices. If verbose is True, a dictionary of matching data with keys
    'coupling', 'label1', 'label2', 'ultra1', 'ultra2', 'dist', 'dist_l2',
    'dist_gw' and 'gw_log'.

    The trees are compared through a Gromov-Wasserstein coupling of their
    ultramatrices. Every leaf of either tree is paired with the node of the
    other tree receiving the most mass from it, and the ultramatrices are
    restricted to the resulting list of pairs.
    """
    ####
    # Get initial data
    ####
    # Get cost matrices and dictionaries
    label1 = {n: j for (j, n) in enumerate(T1_sub.nodes())}
    label2 = {n: j for (j, n) in enumerate(T2_sub.nodes())}
    C1, idx_dict1 = get_ultramatrix(T1_sub, height1_sub, label1)
    C2, idx_dict2 = get_ultramatrix(T2_sub, height2_sub, label2)
    # Get leaf node labels (the root is found once per tree)
    root1 = get_key(height1_sub, max(list(height1_sub.values())))[0]
    root2 = get_key(height2_sub, max(list(height2_sub.values())))[0]
    leaf_nodes1 = [n for n in T1_sub.nodes() if T1_sub.degree(n) == 1 and n != root1]
    leaf_nodes2 = [n for n in T2_sub.nodes() if T2_sub.degree(n) == 1 and n != root2]
    # Compute coupling
    p1 = ot.unif(C1.shape[0])
    p2 = ot.unif(C2.shape[0])
    loss_fun = 'square_loss'
    # log=True is needed to get the (distance, log) pair; the coupling is
    # stored in the log under 'T'.
    d, log = ot.gromov.gromov_wasserstein2(C1, C2, p1, p2, loss_fun, log = True)
    coup = log['T']
    ####
    # Create list of matched points Pi
    ####
    Pi = []
    for leaf in leaf_nodes1:
        # Find where the leaf is matched (row of the coupling for this leaf)
        matched_node_coup_idx = np.argmax(coup[label1[leaf], :])
        # Add ordered pair to Pi
        Pi.append((leaf, idx_dict2[matched_node_coup_idx]))
    for leaf in leaf_nodes2:
        # Find where the leaf is matched (column of the coupling for this leaf)
        matched_node_coup_idx = np.argmax(coup[:, label2[leaf]])
        # Add ordered pair to Pi
        Pi.append((idx_dict1[matched_node_coup_idx], leaf))
    # Remove repeated pairs
    Pi = list(set(Pi))
    ####
    # Create new ultramatrices and compute interleaving distance
    ####
    indices_1 = [label1[pair[0]] for pair in Pi]
    indices_2 = [label2[pair[1]] for pair in Pi]
    C1New = C1[indices_1, :][:, indices_1]
    C2New = C2[indices_2, :][:, indices_2]
    dist = matrix_ell_infinity_distance(C1New, C2New)
    dist_l2 = np.sqrt(np.sum((C1New - C2New)**2))
    ####
    # Collect results for output
    ####
    if not verbose:
        return dist
    # New labelings: each node is labeled by the positions in Pi where it occurs
    labels1New = dict()
    labels2New = dict()
    for j, pair in enumerate(Pi):
        labels1New.setdefault(pair[0], []).append(j)
        labels2New.setdefault(pair[1], []).append(j)
    res = dict()
    res['coupling'] = coup
    res['label1'] = labels1New
    res['label2'] = labels2New
    res['ultra1'] = C1New
    res['ultra2'] = C2New
    res['dist'] = dist
    res['dist_l2'] = dist_l2
    res['dist_gw'] = d
    res['gw_log'] = log
    return res
def merge_tree_interleaving_distance(MT1, MT2, mesh, verbose = True, return_subdivided = False):
    """Estimate the interleaving distance between two MergeTree objects.

    Both trees are subdivided along a common mesh of heights and compared
    with interleaving_subdivided_trees.

    Input: two MergeTree objects, mesh size, verbose flag passed on to
    interleaving_subdivided_trees
    Output: the result of interleaving_subdivided_trees; if return_subdivided
    is True, the tuple (MT1_sub, MT2_sub, res) where the first two entries
    are MergeTree objects built from the subdivided trees without
    simplification
    """
    T1_sub, height1_sub = get_heights_and_subdivide_edges(MT1.tree, MT1.height, MT2.height, mesh)
    T2_sub, height2_sub = get_heights_and_subdivide_edges(MT2.tree, MT2.height, MT1.height, mesh)
    res = interleaving_subdivided_trees(T1_sub, height1_sub, T2_sub, height2_sub, verbose = verbose)
    if not return_subdivided:
        return res
    MT1_sub = MergeTree(tree = T1_sub, height = height1_sub, simplify = False)
    MT2_sub = MergeTree(tree = T2_sub, height = height2_sub, simplify = False)
    return MT1_sub, MT2_sub, res
"""
Matching decorated merge trees and estimating interleaving distance
"""
def linkage_to_merge_tree(L, X):
    """Turn a linkage matrix into merge tree data.

    This is a second definition of a function that is also defined earlier
    in this file; being later, it is the one in effect after import.

    Input: linkage matrix L and the array X it was computed from (only the
    number of rows of X is used, as the number of leaves)
    Output: (T, height), an nx.Graph and a {node: height} dictionary. Leaves
    have height 0; the node created by row j of L is num_leaves + j, with
    height L[j, 2].
    """
    num_leaves = X.shape[0]
    num_merges = L.shape[0]
    height = dict.fromkeys(range(num_leaves), 0)
    edge_list = []
    for j in range(num_merges):
        parent = num_leaves + j
        edge_list.extend([(int(L[j, 0]), parent), (int(L[j, 1]), parent)])
        height[parent] = L[j, 2]
    nodeIDs = np.unique(L[:, :2]).astype(int)
    T = nx.Graph()
    T.add_nodes_from(nodeIDs)
    T.add_edges_from(edge_list)
    return T, height
def get_descendent_leaves(T, height, vertex):
    """List the leaves of a merge tree lying below a given vertex.

    Input: merge tree as T, height; a vertex of T
    Output: list of the leaves (degree-1 nodes other than the root) whose
    path to the root passes through ``vertex``. A leaf counts as its own
    descendent.
    """
    root = get_key(height, max(list(height.values())))[0]
    leaves = [n for n in T.nodes() if T.degree(n) == 1 and n != root]
    return [leaf for leaf in leaves
            if vertex in nx.shortest_path(T, source = leaf, target = root)]
def truncate_bar(bar,height):
if height <= bar[0]:
truncated_bar = bar