-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathlabel_internal_nodes.py
executable file
·91 lines (86 loc) · 4.01 KB
/
label_internal_nodes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#!/usr/bin/env python3
'''
Label the internal nodes of a given tree with the individual inside which the corresponding ancestral node existed.
'''
TREESWIFT_IMPORT_ERROR = "Error importing TreeSwift. Install with: pip3 install treeswift"
MUTATION_FILE_ERROR = "If you specify an input mutation tree file, you must also specify an output mutation tree file (and vice-versa)"
# TreeSwift is the only third-party dependency. If it is missing, re-raise
# with installation instructions while keeping the original error as the cause.
# Only ImportError is intercepted so that unrelated failures (e.g. Ctrl-C
# during import) are not misreported as a missing package.
try:
    from treeswift import read_tree_newick
except ImportError as err:
    raise ImportError(TREESWIFT_IMPORT_ERROR) from err
def read_infections(lines):
    '''
    Parse a transmission network into infection times and the set of seeds.

    Each non-blank line must contain three tab-separated fields ``u``, ``v``, ``t``.
    Lines may be ``str`` or ``bytes`` (as yielded by a gzip file opened in binary mode).
    Blank lines are skipped.

    Returns a tuple ``(inf, seeds)``:
      * ``inf`` maps each ``v`` to its infection time as a float. Seeds (lines whose
        ``u`` is the literal string ``'None'``) map to ``-inf``; for everyone else the
        first time seen is kept. The key ``None`` maps to ``+inf`` so that an
        unlabeled node always sorts after any labeled one.
      * ``seeds`` is the set of all ``v`` that appeared on a ``'None'`` line.

    Raises ValueError if a non-blank line does not have exactly three fields or if
    ``t`` of a non-seed line is not a valid float.
    '''
    inf = {None: float('inf')}
    seeds = set()
    for line in lines:
        if isinstance(line, bytes):
            line = line.decode()
        line = line.strip()
        if not line:
            continue  # tolerate blank / trailing empty lines
        u, v, t = line.split('\t')
        if u == 'None':
            seeds.add(v)
            inf[v] = float('-inf')
        elif v not in inf:
            inf[v] = float(t)
    return inf, seeds


def map_time_to_mutation_nodes(tt, tm):
    '''
    Build a dict mapping every node of the time tree ``tt`` to a node of the
    mutation tree ``tm``.

    Leaves are matched by label. Internal nodes are matched by walking up from each
    leaf in both trees in lockstep, stopping as soon as an already-mapped time-tree
    ancestor is reached.

    NOTE(review): this presumably requires both trees to have the same topology and
    leaf labels; a leaf label missing from ``tm`` raises KeyError -- confirm with callers.
    '''
    tm_leaf_by_label = {leaf.label: leaf for leaf in tm.traverse_leaves()}
    tt2tm = {leaf: tm_leaf_by_label[leaf.label] for leaf in tt.traverse_leaves()}
    for leaf_tt in tt.traverse_leaves():
        anc_tm = tt2tm[leaf_tt].parent
        for anc_tt in leaf_tt.traverse_ancestors(include_self=False):
            if anc_tt in tt2tm:
                break  # everything above this ancestor was mapped via another leaf
            tt2tm[anc_tt] = anc_tm
            anc_tm = anc_tm.parent
    return tt2tm


def assign_persons(tt, inf, seeds, tt2tm=None):
    '''
    Label the internal nodes of ``tt`` (in place) with the individual chosen for them.

    Leaves take the second ``|``-separated field of their label as their individual.
    For an internal node:
      * if more than one child is assigned to a seed, the node gets no individual;
      * otherwise it gets the individual of the child with the smallest
        ``(infection time, individual)`` pair.
    Internal nodes that received an individual have their label overwritten with it;
    if ``tt2tm`` is given, the mapped mutation-tree node is relabeled too.
    Leaf labels are never modified.

    Returns the dict mapping every node of ``tt`` to its individual (or None).
    Raises KeyError if a child's individual is absent from ``inf``.
    '''
    person = {}
    for node in tt.traverse_postorder():
        if node.is_leaf():
            person[node] = node.label.split('|')[1]
            continue
        # NOTE(review): this counts seed-labeled children, not distinct seeds, so two
        # children inside the *same* seed also leave the parent unlabeled -- confirm
        # that this is intended.
        if sum(person[child] in seeds for child in node.children) > 1:
            person[node] = None
        else:
            person[node] = min((inf[person[child]], person[child]) for child in node.children)[1]
        if person[node] is not None:
            node.label = person[node]
            if tt2tm is not None:
                tt2tm[node].label = person[node]
    return person


def write_newick(tree, path):
    '''
    Write ``tree.newick()`` followed by a newline to ``path``.

    A path ending in ``.gz`` (case-insensitive) is written gzip-compressed at level 9.
    The file is always closed, even if writing fails.
    '''
    from gzip import open as gopen
    text = tree.newick() + '\n'
    if path.lower().endswith('.gz'):
        with gopen(path, 'wb', 9) as out:
            out.write(text.encode())
    else:
        with open(path, 'w') as out:
            out.write(text)


# main function
if __name__ == "__main__":
    # parse args
    import argparse
    from gzip import open as gopen
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-tn', '--transmission_network', required=True, type=str, help="Input Transmission Network File")
    parser.add_argument('-tt', '--tree_time', required=True, type=str, help="Input Time Tree File")
    parser.add_argument('-tm', '--tree_mutation', required=False, type=str, default=None, help="Input Mutation Tree File")
    parser.add_argument('-ot', '--output_time', required=True, type=str, help="Output Labeled Time Tree File")
    parser.add_argument('-om', '--output_mutation', required=False, type=str, default=None, help="Output Mutation Tree File")
    args, _ = parser.parse_known_args()  # unrecognized arguments are deliberately ignored

    # the mutation tree input and output must be given together or not at all
    if (args.tree_mutation is None) != (args.output_mutation is None):
        raise ValueError(MUTATION_FILE_ERROR)

    # load trees
    tt = read_tree_newick(args.tree_time)
    tm = None if args.tree_mutation is None else read_tree_newick(args.tree_mutation)

    # read seeds and infection times from transmission network
    tn_open = gopen if args.transmission_network.lower().endswith('.gz') else open
    with tn_open(args.transmission_network) as tn:
        inf, seeds = read_infections(tn)

    # if mutation tree specified, map nodes from time tree to mutation tree
    tt2tm = None if tm is None else map_time_to_mutation_nodes(tt, tm)

    # label internal nodes
    assign_persons(tt, inf, seeds, tt2tm)

    # output resulting tree(s); files are opened only now so that a failure above
    # does not leave truncated output files behind
    write_newick(tt, args.output_time)
    if args.output_mutation is not None:
        write_newick(tm, args.output_mutation)