diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 64ec943..21c7bc0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,4 +3,4 @@ repos: rev: stable hooks: - id: black - language_version: python3.6 \ No newline at end of file + language_version: python3.8 diff --git a/make_prg/__init__.py b/make_prg/__init__.py index 4c2e919..964ce26 100644 --- a/make_prg/__init__.py +++ b/make_prg/__init__.py @@ -1,3 +1,11 @@ +# ___Constants/Aliases___ # +from Bio.AlignIO import MultipleSeqAlignment + +MSA = MultipleSeqAlignment +NESTING_LVL = 5 +MIN_MATCH_LEN = 7 + +# ___Version___ # from pkg_resources import get_distribution try: diff --git a/make_prg/__main__.py b/make_prg/__main__.py index ab84914..8df8575 100644 --- a/make_prg/__main__.py +++ b/make_prg/__main__.py @@ -1,9 +1,8 @@ import argparse +import logging -import make_prg - -NESTING_LVL = 5 -MIN_MATCH_LEN = 7 +from make_prg import __version__ +from make_prg.subcommands import prg_from_msa def main(): @@ -13,80 +12,29 @@ def main(): description="script to run make_prg subcommands", ) - parser.add_argument("--version", action="version", version=make_prg.__version__) + parser.add_argument("--version", action="version", version=__version__) subparsers = parser.add_subparsers( title="Available subcommands", help="", metavar="" ) - # _____________________________ prg_from_msa ______________________________# - subparser_prg_from_msa = subparsers.add_parser( - "prg_from_msa", - usage="make_prg prg_from_msa [options] ", - help="Make PRG from multiple sequence alignment", - ) - - subparser_prg_from_msa.add_argument( - "MSA", - action="store", - type=str, - help=( - "Input file: a multiple sequence alignment in supported alignment_format. " - "If not in aligned fasta alignment_format, use -f to input the " - "alignment_format type" - ), - ) - subparser_prg_from_msa.add_argument( - "-f", - "--alignment_format", - dest="alignment_format", - action="store", - default="fasta", - help=( - "Alignment format of MSA, must be a biopython AlignIO input " - "alignment_format. See http://biopython.org/wiki/AlignIO. Default: fasta" - ), - ) - subparser_prg_from_msa.add_argument( - "--max_nesting", - dest="max_nesting", - action="store", - type=int, - default=NESTING_LVL, - help="Maximum number of levels to use for nesting. Default: {}".format( - NESTING_LVL - ), - ) - subparser_prg_from_msa.add_argument( - "--min_match_length", - dest="min_match_length", - action="store", - type=int, - default=MIN_MATCH_LEN, - help=( - "Minimum number of consecutive characters which must be identical for a " - "match. Default: {}".format(MIN_MATCH_LEN) - ), - ) - subparser_prg_from_msa.add_argument( - "-p", "--prefix", dest="output_prefix", action="store", help="Output prefix" - ) - subparser_prg_from_msa.add_argument( - "--no_overwrite", - dest="no_overwrite", - action="store_true", - help="Do not overwrite pre-existing prg file with same name", - ) - subparser_prg_from_msa.add_argument( + parser.add_argument( "-v", "--verbose", dest="verbose", action="store_true", help="Run with high verbosity " "(debug level logging)", ) - subparser_prg_from_msa.set_defaults(func=make_prg.subcommands.prg_from_msa.run) + + prg_from_msa.register_parser(subparsers) args = parser.parse_args() + if args.verbose: + log_level = logging.DEBUG + else: + log_level = logging.INFO + logging.basicConfig(level=log_level, handlers=[]) + if hasattr(args, "func"): args.func(args) else: diff --git a/make_prg/exceptions.py b/make_prg/exceptions.py deleted file mode 100644 index fbd0cef..0000000 --- a/make_prg/exceptions.py +++ /dev/null @@ -1,2 +0,0 @@ -class ClusteringError(Exception): - pass diff --git a/make_prg/interval_partition.py b/make_prg/interval_partition.py new file mode 100644 index 0000000..0511661 --- /dev/null +++ b/make_prg/interval_partition.py @@ -0,0 +1,240 @@ +""" +Code responsible for converting a consensus string into a set of disjoint +match/non_match intervals. +""" +from enum import Enum, auto +from typing import List, Tuple, Optional + +from make_prg import MSA + +from make_prg.seq_utils import get_interval_seqs, is_non_match, has_empty_sequence + + +class PartitioningError(Exception): + pass + + +class IntervalType(Enum): + Match = auto() + NonMatch = auto() + + @classmethod + def from_char(cls, letter: str) -> "IntervalType": + if letter == "*": + return IntervalType.NonMatch + else: + return IntervalType.Match + + +def is_type(letter: str, interval_type: IntervalType) -> bool: + if IntervalType.from_char(letter) is interval_type: + return True + else: + return False + + +class Interval: + """Stores a closed interval [a,b]""" + + def __init__(self, it_type: IntervalType, start: int, stop: int = None): + self.type = it_type + self.start = start + if stop is not None: + assert stop >= start + self.stop = stop if stop is not None else start + + def modify_by(self, left_delta: int, right_delta: int): + self.start += left_delta + self.stop += right_delta + + def contains(self, position: int): + return self.start <= position <= self.stop + + def __len__(self) -> int: + return self.stop - self.start + 1 + + def __lt__(self, other: "Interval") -> bool: + return self.start < other.start + + def __eq__(self, other: "Interval") -> bool: + return ( + self.start == other.start + and self.stop == other.stop + and self.type is other.type + ) + + def __repr__(self): + return f"[{self.start}, {self.stop}]" + + +Intervals = List[Interval] + + +class IntervalPartitioner: + """Produces a list of intervals in which we have + consensus sequence longer than min_match_length, and + a list of the non-match intervals left.""" + + def __init__(self, consensus_string: str, min_match_length: int, alignment: MSA): + self._match_intervals: Intervals = list() + self._non_match_intervals: Intervals = list() + self.mml = min_match_length + + if len(consensus_string) < self.mml: + # In this case, a match of less than the min_match_length gets counted + # as a match (usually, it counts as a non_match) + it_type = IntervalType.Match + if any(map(is_non_match, consensus_string)): + it_type = IntervalType.NonMatch + self._append(Interval(it_type, 0, len(consensus_string) - 1)) + + else: + cur_interval = self._new_interval(consensus_string[0], 0) + + for i, letter in enumerate(consensus_string[1:], start=1): + if is_type(letter, cur_interval.type): + cur_interval.modify_by(0, 1) # simple interval extension + else: + new_interval = self._add_interval(cur_interval, alignment) + if new_interval is None: + cur_interval = self._new_interval(letter, i) + else: + cur_interval = new_interval + self._add_interval(cur_interval, alignment, end=True) + + self.enforce_multisequence_nonmatch_intervals( + self._match_intervals, self._non_match_intervals, alignment + ) + self.enforce_alignment_interval_bijection( + self._match_intervals, + self._non_match_intervals, + alignment.get_alignment_length(), + ) + + def get_intervals(self) -> Tuple[Intervals, Intervals, Intervals]: + return ( + sorted(self._match_intervals), + sorted(self._non_match_intervals), + sorted(self._match_intervals + self._non_match_intervals), + ) + + def _new_interval(self, letter: str, start_pos: int) -> Interval: + return Interval(IntervalType.from_char(letter), start_pos) + + def _append(self, interval: Interval): + if interval.type is IntervalType.Match: + self._match_intervals.append(interval) + else: + self._non_match_intervals.append(interval) + + def _pop(self, it_type: IntervalType) -> Interval: + if it_type is IntervalType.Match: + return self._match_intervals.pop() + else: + return self._non_match_intervals.pop() + + def _add_interval( + self, interval: Interval, alignment: MSA, end: bool = False + ) -> Optional[Interval]: + """ + i)If we are given a match interval < min_match_length, we return an extended non_match interval + ii)If we are given a non_match interval containing 1+ empty sequence, we pad it with + previous match_interval, if any, to avoid empty alleles in resulting prg. + """ + if interval.type is IntervalType.Match: + # The +1 is because we also extend the non_match interval + if len(interval) < self.mml: + try: + last_non_match = self._pop(IntervalType.NonMatch) + last_non_match.modify_by(0, len(interval) + 1) + except IndexError: + last_non_match = Interval( + IntervalType.NonMatch, interval.start, interval.stop + 1 + ) + if end: # If this is final call, go to append the interval + last_non_match.modify_by(0, -1) + self._append(last_non_match) + return last_non_match + else: + if len(self._match_intervals) > 0 and has_empty_sequence( + alignment, (interval.start, interval.stop) + ): + # Pad interval with sequence to avoid empty alleles + len_match = len(self._match_intervals[-1]) + if len_match - 1 < self.mml: + # Case: match is now too small, converted to non_match + self._match_intervals.pop() + interval.modify_by(-1 * len_match, 0) + if len(self._non_match_intervals) > 0: + # Case: merge previous non_match with this non_match + self._non_match_intervals[-1].modify_by(0, len(interval)) + return None + else: + self._match_intervals[-1].modify_by(0, -1) + interval.modify_by(-1, 0) + + self._append(interval) + return None + + @classmethod + def enforce_multisequence_nonmatch_intervals( + cls, match_intervals: Intervals, non_match_intervals: Intervals, alignment: MSA + ) -> None: + """ + Goes through non-match intervals and makes sure there is more than one sequence there, else makes it a match + interval. + Modifies the intervals in-place. + Example reasons for such a conversion to occur: + - 'N' in a sequence causes it to be filtered out, and left with a single useable sequence + - '-' in sequences causes them to appear different, but they are the same + """ + if len(alignment) == 0: # For testing convenience + return + for i in reversed(range(len(non_match_intervals))): + interval = non_match_intervals[i] + interval_alignment = alignment[:, interval.start : interval.stop + 1] + interval_seqs = get_interval_seqs(interval_alignment) + if len(interval_seqs) < 2: + changed_interval = non_match_intervals[i] + match_intervals.append( + Interval( + IntervalType.Match, + changed_interval.start, + changed_interval.stop, + ) + ) + non_match_intervals.pop(i) + + @classmethod + def enforce_alignment_interval_bijection( + cls, + match_intervals: Intervals, + non_match_intervals: Intervals, + alignment_length: int, + ): + """ + Check each position in an alignment is in one, and one only, (match or non_match) interval + """ + for i in range(alignment_length): + count_match = 0 + for interval in match_intervals: + if interval.contains(i): + count_match += 1 + count_non_match = 0 + for interval in non_match_intervals: + if interval.contains(i): + count_non_match += 1 + + if count_match > 1 or count_non_match > 1: + raise PartitioningError( + f"Failed interval partitioning: position {i}" + " appears in more than one interval" + ) + if ( + not count_match ^ count_non_match + ): # test fails if they are the same integer + msg = ["neither", "nor"] if count_match == 0 else ["both", "and"] + raise PartitioningError( + "Failed interval partitioning: alignment position %d" + "classified as %s match %s non-match " % (i, msg[0], msg[1]) + ) diff --git a/make_prg/io_utils.py b/make_prg/io_utils.py index a03ae6b..1eff905 100644 --- a/make_prg/io_utils.py +++ b/make_prg/io_utils.py @@ -5,12 +5,11 @@ from Bio import AlignIO +from make_prg import MSA from make_prg.prg_encoder import PrgEncoder, PRG_Ints -def load_alignment_file( - msa_file: str, alignment_format: str -) -> AlignIO.MultipleSeqAlignment: +def load_alignment_file(msa_file: str, alignment_format: str) -> MSA: logging.info("Read from MSA file %s", msa_file) if ".gz" in msa_file: logging.debug("MSA is gzipped") diff --git a/make_prg/make_prg_from_msa.py b/make_prg/make_prg_from_msa.py index d45de20..3e0137f 100644 --- a/make_prg/make_prg_from_msa.py +++ b/make_prg/make_prg_from_msa.py @@ -1,24 +1,25 @@ import logging from collections import defaultdict from typing import List -from itertools import chain import numpy as np -from Bio.AlignIO import MultipleSeqAlignment from sklearn.cluster import KMeans +from make_prg import MSA from make_prg.io_utils import load_alignment_file from make_prg.seq_utils import ( + ambiguous_bases, remove_duplicates, - remove_gaps, get_interval_seqs, - ambiguous_bases, + NONMATCH, ) +from make_prg.interval_partition import IntervalPartitioner class AlignedSeq(object): """ - Object based on a set of aligned sequences. Note min_match_length must be strictly greater than max_nesting + 1. + Object based on a set of aligned sequences. + Note min_match_length must be strictly greater than max_nesting + 1. """ def __init__( @@ -39,18 +40,25 @@ def __init__( self.nesting_level = nesting_level self.min_match_length = min_match_length self.site = site - self.alignment = alignment + self.alignment: MSA = alignment if self.alignment is None: self.alignment = load_alignment_file(msa_file, alignment_format) self.interval = interval - self.consensus = self.get_consensus() + self.consensus = self.get_consensus(self.alignment) self.length = len(self.consensus) - (self.match_intervals, self.non_match_intervals) = self.interval_partition() - self.check_nonmatch_intervals() - self.all_intervals = self.match_intervals + self.non_match_intervals - logging.info("Non match intervals: %s", self.non_match_intervals) - self.all_intervals.sort() + ( + self.match_intervals, + self.non_match_intervals, + self.all_intervals, + ) = IntervalPartitioner( + self.consensus, self.min_match_length, self.alignment + ).get_intervals() + logging.info( + "match intervals: %s; non_match intervals: %s", + self.match_intervals, + self.non_match_intervals, + ) # properties for stats self.subAlignedSeqs = {} @@ -67,135 +75,30 @@ def __init__( else: self.prg = self.get_prg() - def get_consensus(self): - """Given a set of aligment records from AlignIO, creates - a consensus string. + @classmethod + def get_consensus(cls, alignment: MSA): + """ Produces a 'consensus string' from an MSA: at each position of the + MSA, the string has a base if all aligned sequences agree, and a "*" if not. IUPAC ambiguous bases result in non-consensus and are later expanded in the prg. N results in consensus at that position unless they are all N.""" - first_string = str(self.alignment[0].seq) consensus_string = "" - for i, letter in enumerate(first_string): - consensus = True - for record in self.alignment: - if letter == "N" or record.seq[i] == "N": - if letter == "N" and record.seq[i] != "N": - letter = record.seq[i] - continue - if letter != record.seq[i] or record.seq[i] in ambiguous_bases: - consensus = False - break - if consensus and letter != "N": - consensus_string += letter - else: - consensus_string += "*" - assert len(first_string) == len(consensus_string) - return consensus_string - - def interval_partition(self): - """Return a list of intervals in which we have - consensus sequence longer than min_match_length, and - a list of the non-match intervals left.""" - match_intervals = [] - non_match_intervals = [] - match_count, match_start, non_match_start = 0, 0, 0 - - logging.debug("consensus: %s" % self.consensus) - for i in range(self.length): - letter = self.consensus[i] - if letter != "*": - # In a match region. - if match_count == 0: - match_start = i - match_count += 1 - elif match_count > 0: - # Have reached a non-match. Check if previous match string is long enough to add to match_regions - match_string = remove_gaps( - self.consensus[match_start : match_start + match_count] - ) - match_len = len(match_string) - logging.debug("have match string %s" % match_string) - - if match_len >= self.min_match_length: - if non_match_start < match_start: - non_match_intervals.append([non_match_start, match_start - 1]) - logging.debug( - f"add non-match interval [{non_match_start},{match_start - 1}]" - ) - end = match_start + match_count - 1 - match_intervals.append([match_start, end]) - logging.debug(f"add match interval [{match_start},{end}]") - non_match_start = i - match_count = 0 - match_start = non_match_start - - end = self.length - 1 - if self.length < self.min_match_length: - # Special case: a short sequence can still get classified as a match interval - added_interval = "match" if "*" in self.consensus else "non_match" - if added_interval == "match": - match_intervals.append([0, end]) + for i in range(alignment.get_alignment_length()): + column = set([record.seq[i] for record in alignment]) + column = column.difference({"N"}) + if ( + len(ambiguous_bases.intersection(column)) > 0 + or len(column) != 1 + or column == {"-"} + ): + consensus_string += NONMATCH else: - non_match_intervals.append([0, end]) - logging.debug(f"add whole short {added_interval} interval [0,{end}]") - match_count = 0 - non_match_start = end + 1 - - # At end add last intervals - if match_count > 0: - if match_count >= self.min_match_length: - match_intervals.append([match_start, end]) - logging.debug(f"add final match interval [{match_start},{end}]") - if non_match_start < match_start: - end = match_start - 1 - if match_count != self.length and non_match_start <= end: - non_match_intervals.append([non_match_start, end]) - logging.debug(f"add non-match interval [{non_match_start},{end}]") - - # check all stretches of consensus are in an interval, and intervals don't overlap - for i in range(self.length): - count_match = 0 - for interval in match_intervals: - if interval[0] <= i <= interval[1]: - count_match += 1 - count_non_match = 0 - for interval in non_match_intervals: - if interval[0] <= i <= interval[1]: - count_non_match += 1 - - assert count_match | count_non_match, ( - "Failed to correctly identify match intervals: position %d " - "appeared in both/neither match and non-match intervals" % i - ) - assert count_match + count_non_match == 1, ( - "Failed to correctly identify match intervals: position " - "%d appeared in %d intervals" % (i, count_match + count_non_match) - ) + consensus_string += column.pop() - return match_intervals, non_match_intervals - - def check_nonmatch_intervals(self): - """ - Goes through non-match intervals and makes sure there is more than one sequence there, else makes it a match - interval. - Example reasons for such a conversion to occur: - - 'N' in a sequence causes it to be filtered out, and left with a single useable sequence - - '-' in sequences causes them to appear different, but they are the same - """ - for i in reversed(range(len(self.non_match_intervals))): - interval = self.non_match_intervals[i] - interval_alignment = self.alignment[:, interval[0] : interval[1] + 1] - interval_seqs = get_interval_seqs(interval_alignment) - if len(interval_seqs) < 2: - self.match_intervals.append(self.non_match_intervals[i]) - self.non_match_intervals.pop(i) - self.match_intervals.sort() + return consensus_string @classmethod def kmeans_cluster_seqs_in_interval( - self, - interval: List[int], - alignment: MultipleSeqAlignment, - min_match_length: int, + self, interval: List[int], alignment: MSA, min_match_length: int, ) -> List[List[str]]: """Divide sequences in interval into subgroups of similar sequences.""" interval_alignment = alignment[:, interval[0] : interval[1] + 1] @@ -321,26 +224,24 @@ def kmeans_cluster_seqs_in_interval( @classmethod def get_sub_alignment_by_list_id( - self, id_list: List[str], alignment: MultipleSeqAlignment, interval=None + self, id_list: List[str], alignment: MSA, interval=None ): list_records = [record for record in alignment if record.id in id_list] - sub_alignment = MultipleSeqAlignment(list_records) + sub_alignment = MSA(list_records) if interval: sub_alignment = sub_alignment[:, interval[0] : interval[1] + 1] return sub_alignment def get_prg(self): prg = "" - # last_char = None - # skip_char = False for interval in self.all_intervals: if interval in self.match_intervals: # all seqs are not necessarily exactly the same: some can have 'N' # thus still process all of them, to get the one with no 'N'. - sub_alignment = self.alignment[:, interval[0] : interval[1] + 1] + sub_alignment = self.alignment[:, interval.start : interval.stop + 1] seqs = get_interval_seqs(sub_alignment) - assert 0 < len(seqs) <= 1, "Got >1 filtered sequences in match interval" + assert len(seqs) == 1, "Got >1 filtered sequences in match interval" seq = seqs[0] prg += seq @@ -352,36 +253,30 @@ def get_prg(self): # Define the variant seqs to add if (self.nesting_level == self.max_nesting) or ( - interval[1] - interval[0] <= self.min_match_length + interval.stop - interval.start <= self.min_match_length ): - # Have reached max nesting level, just add all variants in interval. logging.debug( "Have reached max nesting level or have a small variant site, so add all variant " "sequences in interval." ) - sub_alignment = self.alignment[:, interval[0] : interval[1] + 1] - logging.debug( - "Variant seqs found: %s" - % list( - remove_duplicates( - [str(record.seq) for record in sub_alignment] - ) - ) - ) + sub_alignment = self.alignment[ + :, interval.start : interval.stop + 1 + ] variant_prgs = get_interval_seqs(sub_alignment) - logging.debug("Which is equivalent to: %s" % variant_prgs) + logging.debug(f"Variant seqs found: {variant_prgs}") else: - # divide sequences into subgroups and define prg for each subgroup. logging.debug( "Divide sequences into subgroups and define prg for each subgroup." ) recur = True id_lists = self.kmeans_cluster_seqs_in_interval( - interval, self.alignment, self.min_match_length + [interval.start, interval.stop], + self.alignment, + self.min_match_length, ) list_sub_alignments = [ self.get_sub_alignment_by_list_id( - id_list, self.alignment, interval + id_list, self.alignment, [interval.start, interval.stop] ) for id_list in id_lists ] @@ -392,8 +287,8 @@ def get_prg(self): "Clustering did not group any sequences together, each seq is a cluster" ) recur = False - elif interval[0] not in self.subAlignedSeqs: - self.subAlignedSeqs[interval[0]] = [] + elif interval.start not in self.subAlignedSeqs: + self.subAlignedSeqs[interval.start] = [] logging.debug( "subAlignedSeqs now has keys: %s", list(self.subAlignedSeqs.keys()), @@ -401,7 +296,7 @@ def get_prg(self): else: logging.debug( "subAlignedSeqs already had key %d in keys: %s. This shouldn't happen.", - interval[0], + interval.start, list(self.subAlignedSeqs.keys()), ) @@ -421,11 +316,7 @@ def get_prg(self): self.site = sub_aligned_seq.site if recur: - # logging.debug("None not in snp_scores - try to add sub__aligned_seq to list in - # dictionary") - self.subAlignedSeqs[interval[0]].append(sub_aligned_seq) - # logging.debug("Length of subAlignedSeqs[%d] is %d", interval[0], - # len(self.subAlignedSeqs[interval[0]])) + self.subAlignedSeqs[interval.start].append(sub_aligned_seq) assert num_clusters == len(variant_prgs), ( "I don't seem to have a sub-prg sequence for all parts of the partition - there are %d " "classes in partition, and %d variant seqs" @@ -437,18 +328,13 @@ def get_prg(self): list(remove_duplicates(variant_prgs)) ), "have repeat variant seqs" - # Add the variant seqs to the prg - prg += "%s%d%s" % ( - self.delim_char, - site_num, - self.delim_char, - ) # considered making it so start of prg was not delim_char, - # but that would defeat the point if it + # Add the variant seqs to the prg. + prg += f"{self.delim_char}{site_num}{self.delim_char}" while len(variant_prgs) > 1: prg += variant_prgs.pop(0) - prg += "%s%d%s" % (self.delim_char, site_num + 1, self.delim_char) + prg += f"{self.delim_char}{site_num + 1}{self.delim_char}" prg += variant_prgs.pop() - prg += "%s%d%s" % (self.delim_char, site_num, self.delim_char) + prg += f"{self.delim_char}{site_num}{self.delim_char}" return prg @@ -491,7 +377,7 @@ def max_nesting_level_reached(self): def prop_in_match_intervals(self): length_match_intervals = 0 for interval in self.match_intervals: - length_match_intervals += interval[1] - interval[0] + 1 + length_match_intervals += interval.stop - interval.start + 1 return length_match_intervals / float(self.length) @property diff --git a/make_prg/seq_utils.py b/make_prg/seq_utils.py index 1b700f1..0139272 100644 --- a/make_prg/seq_utils.py +++ b/make_prg/seq_utils.py @@ -1,8 +1,19 @@ import logging -from typing import Generator, Sequence +from typing import Generator, Sequence, Tuple import itertools -from Bio import AlignIO +from make_prg import MSA + +NONMATCH = "*" +GAP = "-" + + +def is_non_match(letter: str): + return letter == NONMATCH + + +def is_gap(letter: str): + return letter == GAP def remove_duplicates(seqs: Sequence) -> Generator: @@ -14,10 +25,6 @@ def remove_duplicates(seqs: Sequence) -> Generator: yield x -def remove_gaps(sequence: str) -> str: - return sequence.replace("-", "") - - iupac = { "R": "GA", "Y": "TC", @@ -35,7 +42,15 @@ def remove_gaps(sequence: str) -> str: ambiguous_bases = allowed_bases.difference(standard_bases) -def get_interval_seqs(interval_alignment: AlignIO.MultipleSeqAlignment): +def has_empty_sequence(alignment: MSA, interval: Tuple[int, int]) -> bool: + sub_alignment = alignment[:, interval[0] : interval[1] + 1] + for record in sub_alignment: + if all(map(is_gap, record.seq)): + return True + return False + + +def get_interval_seqs(interval_alignment: MSA): """ Replace - with nothing, remove seqs containing N or other non-allowed letters and duplicate sequences containing RYKMSW, replacing with AGCT alternatives @@ -60,7 +75,7 @@ def get_interval_seqs(interval_alignment: AlignIO.MultipleSeqAlignment): if len(expanded_set) == 0: logging.warning( - "WARNING: Every sequence must have contained an N in this slice - redo sequence curation because this is nonsense" + "WARNING: Every sequence must have contained an N in this slice - redo sequence curation" ) logging.warning(f'Sequences were: {" ".join(callback_seqs)}') logging.warning( diff --git a/make_prg/subcommands/prg_from_msa.py b/make_prg/subcommands/prg_from_msa.py index 06b365c..4738b58 100644 --- a/make_prg/subcommands/prg_from_msa.py +++ b/make_prg/subcommands/prg_from_msa.py @@ -2,7 +2,70 @@ import os from pathlib import Path -from make_prg import make_prg_from_msa, io_utils +from make_prg import make_prg_from_msa, io_utils, NESTING_LVL, MIN_MATCH_LEN + + +def register_parser(subparsers): + subparser_prg_from_msa = subparsers.add_parser( + "prg_from_msa", + usage="make_prg prg_from_msa [options] ", + help="Make PRG from multiple sequence alignment", + ) + + subparser_prg_from_msa.add_argument( + "MSA", + action="store", + type=str, + help=( + "Input file: a multiple sequence alignment in supported alignment_format. " + "If not in aligned fasta alignment_format, use -f to input the " + "alignment_format type" + ), + ) + subparser_prg_from_msa.add_argument( + "-f", + "--alignment_format", + dest="alignment_format", + action="store", + default="fasta", + help=( + "Alignment format of MSA, must be a biopython AlignIO input " + "alignment_format. See http://biopython.org/wiki/AlignIO. Default: fasta" + ), + ) + subparser_prg_from_msa.add_argument( + "--max_nesting", + dest="max_nesting", + action="store", + type=int, + default=NESTING_LVL, + help="Maximum number of levels to use for nesting. Default: {}".format( + NESTING_LVL + ), + ) + subparser_prg_from_msa.add_argument( + "--min_match_length", + dest="min_match_length", + action="store", + type=int, + default=MIN_MATCH_LEN, + help=( + "Minimum number of consecutive characters which must be identical for a " + "match. Default: {}".format(MIN_MATCH_LEN) + ), + ) + subparser_prg_from_msa.add_argument( + "-p", "--prefix", dest="output_prefix", action="store", help="Output prefix" + ) + subparser_prg_from_msa.add_argument( + "--no_overwrite", + dest="no_overwrite", + action="store_true", + help="Do not overwrite pre-existing prg file with same name", + ) + subparser_prg_from_msa.set_defaults(func=run) + + return subparser_prg_from_msa def run(options): @@ -18,23 +81,17 @@ def run(options): options.min_match_length, ) - if options.verbose: - log_level = logging.DEBUG - msg = "Using debug logging" - else: - log_level = logging.INFO - msg = "Using info logging" - + # Set up file logging log_file = f"{prefix}.log" if os.path.exists(log_file): os.unlink(log_file) - logging.basicConfig( - filename=log_file, - level=log_level, - format="%(asctime)s %(message)s", - datefmt="%d/%m/%Y %I:%M:%S", + formatter = logging.Formatter( + fmt="%(levelname)s %(asctime)s %(message)s", datefmt="%d/%m/%Y %I:%M:%S" ) - logging.info(msg) + handler = logging.FileHandler(log_file) + handler.setFormatter(formatter) + logging.getLogger().addHandler(handler) + logging.info( "Input parameters max_nesting: %d, min_match_length: %d", options.max_nesting, diff --git a/tests/__init__.py b/tests/__init__.py index e69de29..a4b2330 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,14 @@ +from typing import List + +from Bio.Seq import Seq +from Bio.SeqRecord import SeqRecord + +from make_prg import MSA + + +def make_alignment(seqs: List[str], ids: List[str] = None) -> MSA: + if ids is None: + seqrecords = [SeqRecord(Seq(seq), id=f"s_{i}") for i, seq in enumerate(seqs)] + else: + seqrecords = [SeqRecord(Seq(seq), id=ID) for seq, ID in zip(seqs, ids)] + return MSA(seqrecords) diff --git a/tests/test_interval_partition.py b/tests/test_interval_partition.py new file mode 100644 index 0000000..d4ca450 --- /dev/null +++ b/tests/test_interval_partition.py @@ -0,0 +1,167 @@ +from unittest import TestCase +from typing import List + +from tests import make_alignment, MSA +from make_prg.interval_partition import ( + IntervalType, + Interval, + IntervalPartitioner, + PartitioningError, +) + +Lists = List[List[int]] + +Match = IntervalType.Match +NonMatch = IntervalType.NonMatch + + +def make_typed_intervals(lists: Lists, it_type: IntervalType): + result = list() + for elem in lists: + result.append(Interval(it_type, elem[0], elem[1])) + return result + + +def make_intervals(match_lists: Lists, non_match_lists: Lists): + return ( + make_typed_intervals(match_lists, Match), + make_typed_intervals(non_match_lists, NonMatch), + ) + + +class TestIntervalConsistency(TestCase): + def test_nonmatch_interval_switching_indels(self): + """Because the sequences are the same, despite different alignment""" + alignment = make_alignment(["A---A", "A-A--"]) + match_intervals, non_match_intervals = make_intervals([], [[0, 5]]) + IntervalPartitioner.enforce_multisequence_nonmatch_intervals( + match_intervals, non_match_intervals, alignment + ) + self.assertEqual(match_intervals, make_typed_intervals([[0, 5]], Match)) + self.assertEqual(non_match_intervals, []) + + def test_nonmatch_interval_switching_Ns(self): + """'N's make sequences get removed""" + alignment = make_alignment(["ANAAA", "ATAAT"]) + match_intervals, non_match_intervals = make_intervals([], [[0, 5]]) + IntervalPartitioner.enforce_multisequence_nonmatch_intervals( + match_intervals, non_match_intervals, alignment + ) + self.assertEqual(match_intervals, make_typed_intervals([[0, 5]], Match)) + self.assertEqual(non_match_intervals, []) + + def test_position_in_several_intervals_fails(self): + match_intervals = make_typed_intervals([[0, 1], [1, 2]], Match) + with self.assertRaises(PartitioningError): + IntervalPartitioner.enforce_alignment_interval_bijection( + match_intervals, [], 3 + ) + + def test_position_in_no_interval_fails(self): + match_intervals = make_typed_intervals([[0, 1]], Match) + with self.assertRaises(PartitioningError): + IntervalPartitioner.enforce_alignment_interval_bijection( + match_intervals, [], 3 + ) + + def test_position_in_match_and_nonmatch_intervals_fails(self): + match_intervals, nmatch_intervals = make_intervals([[0, 2]], [[2, 3]]) + with self.assertRaises(PartitioningError): + IntervalPartitioner.enforce_alignment_interval_bijection( + match_intervals, nmatch_intervals, 4 + ) + + def test_bijection_respected_passes(self): + match_intervals, nmatch_intervals = make_intervals([[0, 2], [5, 10]], [[3, 4]]) + IntervalPartitioner.enforce_alignment_interval_bijection( + match_intervals, nmatch_intervals, 11 + ) + + +class TestIntervalPartitioning(TestCase): + def test_all_non_match(self): + tester = IntervalPartitioner("******", min_match_length=3, alignment=MSA([])) + match, non_match, _ = tester.get_intervals() + self.assertEqual(match, []) + self.assertEqual(non_match, make_typed_intervals([[0, 5]], NonMatch)) + + def test_all_match(self): + tester = IntervalPartitioner("ATATAAA", min_match_length=3, alignment=MSA([])) + match, non_match, _ = tester.get_intervals() + self.assertEqual(match, make_typed_intervals([[0, 6]], Match)) + self.assertEqual(non_match, []) + + def test_short_match_counted_as_non_match(self): + tester = IntervalPartitioner("AT***", min_match_length=3, alignment=MSA([])) + match, non_match, _ = tester.get_intervals() + self.assertEqual(match, []) + self.assertEqual(non_match, make_typed_intervals([[0, 4]], NonMatch)) + + def test_match_non_match_match(self): + tester = IntervalPartitioner("ATT**AAAC", min_match_length=3, alignment=MSA([])) + match, non_match, all_match = tester.get_intervals() + expected_matches = make_typed_intervals([[0, 2], [5, 8]], Match) + expected_non_matches = make_typed_intervals([[3, 4]], NonMatch) + self.assertEqual(match, expected_matches) + self.assertEqual(non_match, expected_non_matches) + # Check interval sorting works + self.assertEqual( + all_match, + [expected_matches[0], expected_non_matches[0], expected_matches[1]], + ) + + def test_end_in_non_match(self): + tester = IntervalPartitioner( + "**ATT**AAA*C", min_match_length=3, alignment=MSA([]) + ) + match, non_match, _ = tester.get_intervals() + self.assertEqual(match, make_typed_intervals([[2, 4], [7, 9]], Match)) + self.assertEqual( + non_match, make_typed_intervals([[0, 1], [5, 6], [10, 11]], NonMatch) + ) + + def test_consensus_smaller_than_min_match_len(self): + """ + Usually, a match smaller than min_match_length counts as non-match, + but if the whole string is smaller than min_match_length, counts as match. + """ + tester1 = IntervalPartitioner("TTATT", min_match_length=7, alignment=MSA([])) + match, non_match, _ = tester1.get_intervals() + self.assertEqual(match, make_typed_intervals([[0, 4]], Match)) + self.assertEqual(non_match, []) + + tester2 = IntervalPartitioner("T*ATT", min_match_length=7, alignment=MSA([])) + match, non_match, _ = tester2.get_intervals() + self.assertEqual(match, []) + self.assertEqual(non_match, make_typed_intervals([[0, 4]], NonMatch)) + + def test_avoid_empty_alleles_long_match(self): + """ + If we let the non-match interval be only [4,5], + this would result in an empty allele in the prg, + so require padding using the preceding match sequence + """ + msa = make_alignment(["TTAAGGTTT", "TTAA--TTT"]) + tester = IntervalPartitioner("TTAA**TTT", min_match_length=3, alignment=msa) + match, non_match, _ = tester.get_intervals() + self.assertEqual(match, make_typed_intervals([[0, 2], [6, 8]], Match)) + self.assertEqual(non_match, make_typed_intervals([[3, 5]], NonMatch)) + + def test_avoid_empty_alleles_short_match(self): + """ + Padding behaviour also expected, but now the leading match interval becomes too + short and collapses to a non_match interval + """ + msa = make_alignment(["TTAGGTTT", "TTA--TTT"]) + tester = IntervalPartitioner("TTA**TTT", min_match_length=3, alignment=msa) + match, non_match, _ = tester.get_intervals() + self.assertEqual(match, make_typed_intervals([[5, 7]], Match)) + self.assertEqual(non_match, make_typed_intervals([[0, 4]], NonMatch)) + + def test_avoid_empty_alleles_previous_non_match_merged(self): + """Edge case of collapsed match interval, part 2""" + msa = make_alignment(["CCTTAGGTTT", "AATTA--TTT"]) + tester = IntervalPartitioner("**TTA**TTT", min_match_length=3, alignment=msa) + match, non_match, _ = tester.get_intervals() + self.assertEqual(match, make_typed_intervals([[7, 9]], Match)) + self.assertEqual(non_match, make_typed_intervals([[0, 6]], NonMatch)) diff --git a/tests/test_make_prg_from_msa.py b/tests/test_make_prg_from_msa.py index da2693a..9ec6f58 100644 --- a/tests/test_make_prg_from_msa.py +++ b/tests/test_make_prg_from_msa.py @@ -1,81 +1,75 @@ import os import random -from unittest import TestCase, mock, skip +from unittest import TestCase, skip -from Bio.AlignIO import MultipleSeqAlignment from Bio.Seq import Seq from Bio.SeqRecord import SeqRecord from make_prg.make_prg_from_msa import AlignedSeq -from make_prg.exceptions import ClusteringError from make_prg.seq_utils import standard_bases +from tests import make_alignment, MSA this_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) data_dir = os.path.join(this_dir, "tests", "data", "make_prg_from_msa") -@mock.patch.object(AlignedSeq, "check_nonmatch_intervals") -@mock.patch.object(AlignedSeq, "get_prg") -@mock.patch.object(AlignedSeq, "get_consensus") -class TestIntervalPartitioning(TestCase): - def test_all_non_match(self, get_consensus, _, __): - get_consensus.return_value = "******" - tester = AlignedSeq("_", alignment="_", min_match_length=3) - match, non_match = tester.interval_partition() - self.assertEqual(match, []) - self.assertEqual(non_match, [[0, 5]]) - - def test_all_match(self, get_consensus, _, __): - get_consensus.return_value = "ATATAAA" - tester = AlignedSeq("_", alignment="_", min_match_length=3) - match, non_match = tester.interval_partition() - self.assertEqual(match, [[0, 6]]) - self.assertEqual(non_match, []) - - def test_short_match_counted_as_non_match(self, get_consensus, _, __): - get_consensus.return_value = "AT***" - tester = AlignedSeq("_", alignment="_", min_match_length=3) - match, non_match = tester.interval_partition() - self.assertEqual(match, []) - self.assertEqual(non_match, [[0, 4]]) - - def test_match_non_match_match(self, get_consensus, _, __): - get_consensus.return_value = "ATT**AAAC" - tester = AlignedSeq("_", alignment="_", min_match_length=3) - match, non_match = tester.interval_partition() - self.assertEqual(match, [[0, 2], [5, 8]]) - self.assertEqual(non_match, [[3, 4]]) - - def test_end_in_non_match(self, get_consensus, _, __): - get_consensus.return_value = "**ATT**AAA*C" - tester = AlignedSeq("_", alignment="_", min_match_length=3) - match, non_match = tester.interval_partition() - self.assertEqual(match, [[2, 4], [7, 9]]) - self.assertEqual(non_match, [[0, 1], [5, 6], [10, 11]]) +class TestConsensusString(TestCase): + def test_all_match(self): + alignment = make_alignment(["AATTA", "AATTA"]) + result = AlignedSeq.get_consensus(alignment) + self.assertEqual(result, "AATTA") + + def test_mixed_match_nonmatch(self): + alignment = make_alignment(["AAGTA", "CATTA"]) + result = AlignedSeq.get_consensus(alignment) + self.assertEqual(result, "*A*TA") + + def test_indel_nonmatch(self): + alignment = make_alignment(["AAAA", "A--A"]) + result = AlignedSeq.get_consensus(alignment) + self.assertEqual(result, "A**A") + + def test_IUPACAmbiguous_nonmatch(self): + alignment = make_alignment(["RYA", "RTA"]) + result = AlignedSeq.get_consensus(alignment) + self.assertEqual(result, "**A") + + def test_N_special_treatment(self): + """ + i)A and N at pos 2 are different, but still consensus + ii)N and N at pos 0 are same, but not consensus""" + alignment = make_alignment(["NTN", "NTA"]) + result = AlignedSeq.get_consensus(alignment) + self.assertEqual(result, "*TA") + + def test_all_gap_nonmatch(self): + alignment = make_alignment(["A--A", "A--A"]) + result = AlignedSeq.get_consensus(alignment) + self.assertEqual(result, "A**A") class TestKmeansClusters(TestCase): def test_one_seq_returns_single_id(self): - alignment = MultipleSeqAlignment([SeqRecord(Seq("AAAT"), id="s1")]) + alignment = MSA([SeqRecord(Seq("AAAT"), id="s1")]) result = AlignedSeq.kmeans_cluster_seqs_in_interval([0, 3], alignment, 1) self.assertEqual(result, [["s1"]]) def test_two_seqs_one_below_min_match_len_separate_clusters(self): - alignment = MultipleSeqAlignment( + alignment = MSA( [SeqRecord(Seq("AATTTAT"), id="s1"), SeqRecord(Seq("AA---AT"), id="s2")] ) result = AlignedSeq.kmeans_cluster_seqs_in_interval([0, 5], alignment, 5) self.assertEqual(result, [["s1"], ["s2"]]) def test_two_identical_seqs_returns_two_ids_clustered(self): - alignment = MultipleSeqAlignment( + alignment = MSA( [SeqRecord(Seq("AAAT"), id="s1"), SeqRecord(Seq("AAAT"), id="s2"),] ) result = AlignedSeq.kmeans_cluster_seqs_in_interval([0, 3], alignment, 1) self.assertEqual(result, [["s1", "s2"]]) def test_sequences_in_short_interval_separate_clusters(self): - alignment = MultipleSeqAlignment( + alignment = MSA( [ SeqRecord(Seq("AAAT"), id="s1"), SeqRecord(Seq("AATT"), id="s2"), @@ -90,14 +84,14 @@ def test_sequences_in_short_interval_separate_clusters(self): "This fails, probably because kmean clustering should never run with this input" ) def test_ambiguous_sequences_in_short_interval_separate_clusters(self): - alignment = MultipleSeqAlignment( + alignment = MSA( [SeqRecord(Seq("ARAT"), id="s1"), SeqRecord(Seq("WAAT"), id="s2"),] ) result = AlignedSeq.kmeans_cluster_seqs_in_interval([0, 3], alignment, 5) self.assertEqual([["s1"], ["s2"]], result) def test_two_identical_sequences_clustered_together(self): - alignment = MultipleSeqAlignment( + alignment = MSA( [ SeqRecord(Seq("AAAT"), id="s1"), SeqRecord(Seq("AAAT"), id="s2"), @@ -108,7 +102,7 @@ def test_two_identical_sequences_clustered_together(self): self.assertEqual([["s1", "s2"], ["s3"]], result) def test_all_sequences_below_min_match_len(self): - alignment = MultipleSeqAlignment( + alignment = MSA( [ SeqRecord(Seq("AA---AT"), id="s1"), SeqRecord(Seq("AA---TT"), id="s2"), @@ -139,14 +133,14 @@ def test_first_sequence_placed_in_first_cluster(self): [random.choice(bases) for _ in range(seq_len)] ) records.append(SeqRecord(Seq(rand_seq), id=f"s{i}")) - alignment = MultipleSeqAlignment(records) + alignment = MSA(records) result = AlignedSeq.kmeans_cluster_seqs_in_interval( [0, seq_len - 1], alignment, 1 ) self.assertTrue(result[0][0] == "s0") def test_one_long_one_short_sequence_separate_and_ordered_clusters(self): - alignment = MultipleSeqAlignment( + alignment = MSA( [ SeqRecord(Seq("AATTAATTATATAATAAC"), id="s1"), SeqRecord(Seq("A--------------AAT"), id="s2"), @@ -163,7 +157,7 @@ def test_one_long_one_short_sequence_separate_and_ordered_clusters(self): self.assertEqual(order_2, [["s2"], ["s1"]]) -def msas_equal(al1: MultipleSeqAlignment, al2: MultipleSeqAlignment): +def msas_equal(al1: MSA, al2: MSA): if len(al1) != len(al2): return False for i in range(len(al1)): @@ -177,7 +171,7 @@ def msas_equal(al1: MultipleSeqAlignment, al2: MultipleSeqAlignment): class TestSubAlignments(TestCase): @classmethod def setUpClass(cls): - cls.alignment = MultipleSeqAlignment( + cls.alignment = MSA( [ SeqRecord(Seq("AAAT"), id="s1"), SeqRecord(Seq("C--C"), id="s2"), @@ -188,7 +182,7 @@ def setUpClass(cls): def test_get_subalignment_sequence_order_maintained2(self): result = AlignedSeq.get_sub_alignment_by_list_id(["s1", "s3"], self.alignment) - expected = MultipleSeqAlignment([self.alignment[0], self.alignment[2]]) + expected = MSA([self.alignment[0], self.alignment[2]]) self.assertTrue(msas_equal(expected, result)) def test_get_subalignment_sequence_order_maintained(self): @@ -196,14 +190,14 @@ def test_get_subalignment_sequence_order_maintained(self): Sequences given rearranged are still output in input order """ result = AlignedSeq.get_sub_alignment_by_list_id(["s3", "s1"], self.alignment) - expected = MultipleSeqAlignment([self.alignment[0], self.alignment[2]]) + expected = MSA([self.alignment[0], self.alignment[2]]) self.assertTrue(msas_equal(expected, result)) def test_get_subalignment_with_interval(self): result = AlignedSeq.get_sub_alignment_by_list_id( ["s2", "s3"], self.alignment, [0, 2] ) - expected = MultipleSeqAlignment( + expected = MSA( [SeqRecord(Seq("C--"), id="s2"), SeqRecord(Seq("AAT"), id="s3"),] ) self.assertTrue(msas_equal(expected, result)) diff --git a/tests/test_seq_utils.py b/tests/test_seq_utils.py index 9c819a3..5804b7c 100644 --- a/tests/test_seq_utils.py +++ b/tests/test_seq_utils.py @@ -1,15 +1,14 @@ -import unittest +from unittest import TestCase -from hypothesis import given -from hypothesis.strategies import text from Bio import AlignIO from Bio.Seq import Seq from Bio.SeqRecord import SeqRecord -from make_prg.seq_utils import remove_gaps, get_interval_seqs +from tests import make_alignment +from make_prg.seq_utils import get_interval_seqs, has_empty_sequence -class TestGetIntervals(unittest.TestCase): +class TestGetIntervals(TestCase): def test_ambiguous_bases_one_seq(self): alignment = AlignIO.MultipleSeqAlignment([SeqRecord(Seq("RWAAT"))]) result = get_interval_seqs(alignment) @@ -31,41 +30,7 @@ def test_first_sequence_in_is_first_sequence_out(self): self.assertEqual(expected, result) -class TestRemoveGaps(unittest.TestCase): - def test_empty_string_returns_empty(self): - seq = "" - - actual = remove_gaps(seq) - expected = "" - - self.assertEqual(actual, expected) - - def test_string_with_no_gaps_returns_original(self): - seq = "ACGT" - - actual = remove_gaps(seq) - expected = seq - - self.assertEqual(actual, expected) - - def test_string_with_one_gaps_returns_original_without_gap(self): - seq = "ACGT-" - - actual = remove_gaps(seq) - expected = "ACGT" - - self.assertEqual(actual, expected) - - def test_string_with_many_gaps_returns_original_without_gap(self): - seq = "A-CGT--" - - actual = remove_gaps(seq) - expected = "ACGT" - - self.assertEqual(actual, expected) - - @given(text()) - def test_all_input_space_doesnt_break(self, seq): - actual = remove_gaps(seq) - - self.assertFalse("-" in actual) +class TestEmptySeq(TestCase): + def test_sub_alignment_with_empty_sequence(self): + msa = make_alignment(["TTAGGTTT", "TTA--TTT", "GGA-TTTT"]) + self.assertTrue(has_empty_sequence(msa, [3, 4]))