-
Notifications
You must be signed in to change notification settings - Fork 5
/
ca_disco.py
104 lines (87 loc) · 4.26 KB
/
ca_disco.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import argparse
import csv
import os
from Bio import AlignIO
from Bio.Align import MultipleSeqAlignment
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import treeswift as ts
from disco import *
def retrieve_alignment(tree, aln_path, format, taxa_set, label_to_species):
    """
    Build the alignment block that corresponds to one decomposed tree.

    Parameters
    ----------------
    tree: single-copy treeswift tree generated by DISCO.
    aln_path: path to the alignment of the genes.
        The row labels should be a superset of the leafset of 'tree'.
    format: file format of the alignment (either "phylip" or "fasta").
    taxa_set: set, the taxon set of the entire dataset
    label_to_species: callable mapping a leaf/row label to its taxon name

    Returns
    ----------------
    The MSA that corresponds to the input tree: one row per taxon in
    'taxa_set', relabelled with the taxon name. Taxa that have no leaf in
    'tree' get an all-gap row of the same length. Rows are ordered with
    MultipleSeqAlignment.sort().

    Raises
    ----------------
    KeyError: if a matching row maps to a taxon that is not in 'taxa_set',
        or two matching rows map to the same taxon (tree not single-copy).
    """
    # Read through a context manager so the file handle is always closed.
    with open(aln_path) as handle:
        aln = AlignIO.read(handle, format)
    blank = "-" * len(aln[0].seq)
    # labels(True, False): presumably leaves only, no internal nodes
    # (treeswift signature is labels(leaves, internal)) -- TODO confirm.
    whitelist = set(tree.labels(True, False))
    # Copy so the caller's taxon set is not mutated by remove() below.
    remaining = set(taxa_set)
    result = MultipleSeqAlignment([])
    # you can't get sequences by name from aln objects in biopython
    # in a better way as far as I can tell
    for record in aln:
        if record.id not in whitelist:
            continue
        taxon_name = label_to_species(record.id)
        result.append(SeqRecord(record.seq, id=taxon_name))
        remaining.remove(taxon_name)
    # Pad every taxon that is absent from this tree with gaps so that all
    # blocks share the same row set and can be concatenated column-wise.
    for taxon_name in remaining:
        result.append(SeqRecord(Seq(blank), id=str(taxon_name)))
    result.sort()
    return result
def main(args):
    """
    Decompose each gene-family tree and write the concatenated alignment.

    Reads the following attributes of 'args':
    alignment: path to a CSV file with one row per gene; the first column is
        the alignment file path. When 'partition' is set every row must
        have exactly two columns (alignment path, model).
    input: path to the newick file with the gene-family trees. Rows of the
        alignment list and trees are paired by position (zip stops at the
        shorter of the two).
    format: alignment file format, used for both reading and writing.
    delimiter: separator between the taxon name and the rest of a leaf label.
    filter: minimum number of leaves a decomposed tree must have to be kept.
    partition: if true, also write "<output_prefix>-partitions.txt".
    output_prefix: prefix of the output file names.

    Raises
    ----------------
    ValueError: if 'partition' is set and a row of the alignment list does
        not have exactly two columns.
    """
    # Close the list file once parsed; skip blank rows, which would
    # otherwise break the (aln_file, *_) unpacking further down.
    with open(args.alignment, "r", newline="") as list_file:
        input_alignments = [row for row in csv.reader(list_file) if row]

    # Validate up front with a real exception: an assert is stripped under
    # "python -O", and failing late would leave an empty partition file.
    if args.partition and any(len(row) != 2 for row in input_alignments):
        raise ValueError("alignment list file format problem")

    def label_to_species(label):
        return label.split(args.delimiter)[0]

    tree_list = ts.read_tree_newick(args.input)
    # treeswift returns a bare Tree (not a list) when the file holds a
    # single tree; normalise so one-gene inputs work too.
    if isinstance(tree_list, ts.Tree):
        tree_list = [tree_list]

    taxa_set = {label_to_species(label)
                for tree in tree_list
                for label in tree.labels(True, False)}

    # init aln with taxa labels
    aln = MultipleSeqAlignment([])
    for taxa in taxa_set:
        aln.append(SeqRecord(Seq(''), id=taxa))
    aln.sort()

    partitions = []
    p_index = 1  # 1-based first column of the next gene's partition
    for (aln_file, *_), tree in zip(input_alignments, tree_list):
        # get_min_root / tag / decompose are expected to come from the
        # star-import of the disco module.
        tree.reroot(get_min_root(tree, label_to_species)[0])
        tag(tree, label_to_species)
        disco_trees = [dtree for dtree in decompose(tree)
                       if dtree.num_nodes(internal=False) >= args.filter]
        for dtree in disco_trees:
            aln += retrieve_alignment(dtree, aln_file, args.format,
                                      taxa_set, label_to_species)
        aln_len = aln.get_alignment_length()
        if aln_len + 1 != p_index:
            partitions.append(f"{p_index:d}-{aln_len:d}")
            p_index = aln_len + 1
        else:
            # No columns were added for this gene (all trees filtered out).
            partitions.append("empty")

    if args.partition:
        with open(f"{args.output_prefix}-partitions.txt", "w") as f:
            for partition, (aln_file, model) in zip(partitions, input_alignments):
                if partition != "empty":
                    gene_name = aln_file.split(os.sep)[-1].split('.')[0]
                    f.write(f"{model}, {gene_name}={partition}\n")

    AlignIO.write(aln, f"{args.output_prefix}-aln.{args.format[:3]}", args.format)
# Command-line entry point: parse the options and hand them to main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="generate concatenation files from gene-family trees using decomposition strategies")
    parser.add_argument("-i", "--input", type=str,
                        help="input tree list file", required=True)
    # The value is used as a filename prefix ("<prefix>-aln.<ext>" and
    # "<prefix>-partitions.txt"), not as a tree list file.
    parser.add_argument("-o", "--output-prefix", type=str, required=True,
                        help="prefix for the output alignment and partition files")
    parser.add_argument("-a", "--alignment", required=True, type=str,
                        help="alignment files list")
    parser.add_argument('-f', '--format', choices=["phylip", "fasta"], required=True,
                        help="alignment file format")
    parser.add_argument('-d', '--delimiter', type=str, default='_',
                        help="delimiter separating taxon label from the rest of the leaf label.")
    parser.add_argument('-m', '--filter', type=int, default=4,
                        help="exclude decomposed trees with fewer than X taxa")
    parser.add_argument('-p', '--partition', action='store_true',
                        help="generate partition file")
    main(parser.parse_args())