From f005a93c2394b3943ef52af4f1671ea15c96ce8f Mon Sep 17 00:00:00 2001 From: Brice Letcher Date: Tue, 11 Aug 2020 12:46:35 +0100 Subject: [PATCH 1/9] Update pre commit python version to 3.8 --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 64ec943..21c7bc0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,4 +3,4 @@ repos: rev: stable hooks: - id: black - language_version: python3.6 \ No newline at end of file + language_version: python3.8 From 28e4ba62c55be9ad4409ce058c8f40210c4d8e63 Mon Sep 17 00:00:00 2001 From: Brice Letcher Date: Tue, 11 Aug 2020 12:46:42 +0100 Subject: [PATCH 2/9] Terser MSA name and better consensus string code * Use generic MSA name for MSA object, defined in __init__ * Much shorter consensus string generation code, with added unit tests to make sure nothing broken --- make_prg/__init__.py | 4 ++ make_prg/exceptions.py | 2 - make_prg/io_utils.py | 5 +-- make_prg/make_prg_from_msa.py | 54 +++++++++++---------------- tests/test_make_prg_from_msa.py | 65 +++++++++++++++++++++++++-------- 5 files changed, 77 insertions(+), 53 deletions(-) delete mode 100644 make_prg/exceptions.py diff --git a/make_prg/__init__.py b/make_prg/__init__.py index 4c2e919..d71d2cb 100644 --- a/make_prg/__init__.py +++ b/make_prg/__init__.py @@ -1,3 +1,7 @@ +from Bio.AlignIO import MultipleSeqAlignment + +MSA = MultipleSeqAlignment + from pkg_resources import get_distribution try: diff --git a/make_prg/exceptions.py b/make_prg/exceptions.py deleted file mode 100644 index fbd0cef..0000000 --- a/make_prg/exceptions.py +++ /dev/null @@ -1,2 +0,0 @@ -class ClusteringError(Exception): - pass diff --git a/make_prg/io_utils.py b/make_prg/io_utils.py index a03ae6b..1eff905 100644 --- a/make_prg/io_utils.py +++ b/make_prg/io_utils.py @@ -5,12 +5,11 @@ from Bio import AlignIO +from make_prg import MSA from make_prg.prg_encoder import PrgEncoder, PRG_Ints -def load_alignment_file( - msa_file: str, alignment_format: str -) -> AlignIO.MultipleSeqAlignment: +def load_alignment_file(msa_file: str, alignment_format: str) -> MSA: logging.info("Read from MSA file %s", msa_file) if ".gz" in msa_file: logging.debug("MSA is gzipped") diff --git a/make_prg/make_prg_from_msa.py b/make_prg/make_prg_from_msa.py index d45de20..ad097fd 100644 --- a/make_prg/make_prg_from_msa.py +++ b/make_prg/make_prg_from_msa.py @@ -1,13 +1,12 @@ import logging from collections import defaultdict from typing import List -from itertools import chain import numpy as np -from Bio.AlignIO import MultipleSeqAlignment from sklearn.cluster import KMeans -from make_prg.io_utils import load_alignment_file +from make_prg import MSA +from make_prg.io_utils import load_alignment_file, MSA from make_prg.seq_utils import ( remove_duplicates, remove_gaps, @@ -18,7 +17,8 @@ class AlignedSeq(object): """ - Object based on a set of aligned sequences. Note min_match_length must be strictly greater than max_nesting + 1. + Object based on a set of aligned sequences. + Note min_match_length must be strictly greater than max_nesting + 1. """ def __init__( @@ -39,12 +39,12 @@ def __init__( self.nesting_level = nesting_level self.min_match_length = min_match_length self.site = site - self.alignment = alignment + self.alignment: MSA = alignment if self.alignment is None: self.alignment = load_alignment_file(msa_file, alignment_format) self.interval = interval - self.consensus = self.get_consensus() + self.consensus = self.get_consensus(self.alignment) self.length = len(self.consensus) (self.match_intervals, self.non_match_intervals) = self.interval_partition() self.check_nonmatch_intervals() @@ -67,36 +67,29 @@ def __init__( else: self.prg = self.get_prg() - def get_consensus(self): - """Given a set of aligment records from AlignIO, creates - a consensus string. + @classmethod + def get_consensus(cls, alignment: MSA): + """ Produces a 'consensus string' from an MSA: at each position of the + MSA, the string has a base if all aligned sequences agree, and a "*" if not. IUPAC ambiguous bases result in non-consensus and are later expanded in the prg. N results in consensus at that position unless they are all N.""" - first_string = str(self.alignment[0].seq) consensus_string = "" - for i, letter in enumerate(first_string): - consensus = True - for record in self.alignment: - if letter == "N" or record.seq[i] == "N": - if letter == "N" and record.seq[i] != "N": - letter = record.seq[i] - continue - if letter != record.seq[i] or record.seq[i] in ambiguous_bases: - consensus = False - break - if consensus and letter != "N": - consensus_string += letter - else: + for i in range(alignment.get_alignment_length()): + column = set([record.seq[i] for record in alignment]) + if "N" in column: + column.remove("N") + if len(ambiguous_bases.intersection(column)) > 0 or len(column) != 1: consensus_string += "*" - assert len(first_string) == len(consensus_string) + else: + consensus_string += column.pop() + return consensus_string def interval_partition(self): """Return a list of intervals in which we have consensus sequence longer than min_match_length, and a list of the non-match intervals left.""" - match_intervals = [] - non_match_intervals = [] + match_intervals, non_match_intervals = list(), list() match_count, match_start, non_match_start = 0, 0, 0 logging.debug("consensus: %s" % self.consensus) @@ -192,10 +185,7 @@ def check_nonmatch_intervals(self): @classmethod def kmeans_cluster_seqs_in_interval( - self, - interval: List[int], - alignment: MultipleSeqAlignment, - min_match_length: int, + self, interval: List[int], alignment: MSA, min_match_length: int, ) -> List[List[str]]: """Divide sequences in interval into subgroups of similar sequences.""" interval_alignment = alignment[:, interval[0] : interval[1] + 1] @@ -321,10 +311,10 @@ def kmeans_cluster_seqs_in_interval( @classmethod def get_sub_alignment_by_list_id( - self, id_list: List[str], alignment: MultipleSeqAlignment, interval=None + self, id_list: List[str], alignment: MSA, interval=None ): list_records = [record for record in alignment if record.id in id_list] - sub_alignment = MultipleSeqAlignment(list_records) + sub_alignment = MSA(list_records) if interval: sub_alignment = sub_alignment[:, interval[0] : interval[1] + 1] return sub_alignment diff --git a/tests/test_make_prg_from_msa.py b/tests/test_make_prg_from_msa.py index da2693a..fcb6cd7 100644 --- a/tests/test_make_prg_from_msa.py +++ b/tests/test_make_prg_from_msa.py @@ -1,19 +1,52 @@ import os import random from unittest import TestCase, mock, skip +from typing import List -from Bio.AlignIO import MultipleSeqAlignment from Bio.Seq import Seq from Bio.SeqRecord import SeqRecord +from make_prg import MSA from make_prg.make_prg_from_msa import AlignedSeq -from make_prg.exceptions import ClusteringError from make_prg.seq_utils import standard_bases this_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) data_dir = os.path.join(this_dir, "tests", "data", "make_prg_from_msa") +def make_alignment(seqs: List[str], ids: List[str] = None) -> MSA: + if ids is None: + seqrecords = [SeqRecord(Seq(seq), id=f"s_{i}") for i, seq in enumerate(seqs)] + else: + seqrecords = [SeqRecord(Seq(seq), id=ID) for seq, ID in zip(seqs, ids)] + return MSA(seqrecords) + + +class TestConsensusString(TestCase): + def test_all_match(self): + alignment = make_alignment(["AATTA", "AATTA"]) + result = AlignedSeq.get_consensus(alignment) + self.assertEqual(result, "AATTA") + + def test_mixed_match_nonmatch(self): + alignment = make_alignment(["AAGTA", "CATTA"]) + result = AlignedSeq.get_consensus(alignment) + self.assertEqual(result, "*A*TA") + + def test_IUPACAmbiguous_nonmatch(self): + alignment = make_alignment(["RYA", "RTA"]) + result = AlignedSeq.get_consensus(alignment) + self.assertEqual(result, "**A") + + def test_N_special_treatment(self): + """ + i)A and N at pos 2 are different, but still consensus + ii)N and N at pos 0 are same, but not consensus""" + alignment = make_alignment(["NTN", "NTA"]) + result = AlignedSeq.get_consensus(alignment) + self.assertEqual(result, "*TA") + + @mock.patch.object(AlignedSeq, "check_nonmatch_intervals") @mock.patch.object(AlignedSeq, "get_prg") @mock.patch.object(AlignedSeq, "get_consensus") @@ -56,26 +89,26 @@ def test_end_in_non_match(self, get_consensus, _, __): class TestKmeansClusters(TestCase): def test_one_seq_returns_single_id(self): - alignment = MultipleSeqAlignment([SeqRecord(Seq("AAAT"), id="s1")]) + alignment = MSA([SeqRecord(Seq("AAAT"), id="s1")]) result = AlignedSeq.kmeans_cluster_seqs_in_interval([0, 3], alignment, 1) self.assertEqual(result, [["s1"]]) def test_two_seqs_one_below_min_match_len_separate_clusters(self): - alignment = MultipleSeqAlignment( + alignment = MSA( [SeqRecord(Seq("AATTTAT"), id="s1"), SeqRecord(Seq("AA---AT"), id="s2")] ) result = AlignedSeq.kmeans_cluster_seqs_in_interval([0, 5], alignment, 5) self.assertEqual(result, [["s1"], ["s2"]]) def test_two_identical_seqs_returns_two_ids_clustered(self): - alignment = MultipleSeqAlignment( + alignment = MSA( [SeqRecord(Seq("AAAT"), id="s1"), SeqRecord(Seq("AAAT"), id="s2"),] ) result = AlignedSeq.kmeans_cluster_seqs_in_interval([0, 3], alignment, 1) self.assertEqual(result, [["s1", "s2"]]) def test_sequences_in_short_interval_separate_clusters(self): - alignment = MultipleSeqAlignment( + alignment = MSA( [ SeqRecord(Seq("AAAT"), id="s1"), SeqRecord(Seq("AATT"), id="s2"), @@ -90,14 +123,14 @@ def test_sequences_in_short_interval_separate_clusters(self): "This fails, probably because kmean clustering should never run with this input" ) def test_ambiguous_sequences_in_short_interval_separate_clusters(self): - alignment = MultipleSeqAlignment( + alignment = MSA( [SeqRecord(Seq("ARAT"), id="s1"), SeqRecord(Seq("WAAT"), id="s2"),] ) result = AlignedSeq.kmeans_cluster_seqs_in_interval([0, 3], alignment, 5) self.assertEqual([["s1"], ["s2"]], result) def test_two_identical_sequences_clustered_together(self): - alignment = MultipleSeqAlignment( + alignment = MSA( [ SeqRecord(Seq("AAAT"), id="s1"), SeqRecord(Seq("AAAT"), id="s2"), @@ -108,7 +141,7 @@ def test_two_identical_sequences_clustered_together(self): self.assertEqual([["s1", "s2"], ["s3"]], result) def test_all_sequences_below_min_match_len(self): - alignment = MultipleSeqAlignment( + alignment = MSA( [ SeqRecord(Seq("AA---AT"), id="s1"), SeqRecord(Seq("AA---TT"), id="s2"), @@ -139,14 +172,14 @@ def test_first_sequence_placed_in_first_cluster(self): [random.choice(bases) for _ in range(seq_len)] ) records.append(SeqRecord(Seq(rand_seq), id=f"s{i}")) - alignment = MultipleSeqAlignment(records) + alignment = MSA(records) result = AlignedSeq.kmeans_cluster_seqs_in_interval( [0, seq_len - 1], alignment, 1 ) self.assertTrue(result[0][0] == "s0") def test_one_long_one_short_sequence_separate_and_ordered_clusters(self): - alignment = MultipleSeqAlignment( + alignment = MSA( [ SeqRecord(Seq("AATTAATTATATAATAAC"), id="s1"), SeqRecord(Seq("A--------------AAT"), id="s2"), @@ -163,7 +196,7 @@ def test_one_long_one_short_sequence_separate_and_ordered_clusters(self): self.assertEqual(order_2, [["s2"], ["s1"]]) -def msas_equal(al1: MultipleSeqAlignment, al2: MultipleSeqAlignment): +def msas_equal(al1: MSA, al2: MSA): if len(al1) != len(al2): return False for i in range(len(al1)): @@ -177,7 +210,7 @@ def msas_equal(al1: MultipleSeqAlignment, al2: MultipleSeqAlignment): class TestSubAlignments(TestCase): @classmethod def setUpClass(cls): - cls.alignment = MultipleSeqAlignment( + cls.alignment = MSA( [ SeqRecord(Seq("AAAT"), id="s1"), SeqRecord(Seq("C--C"), id="s2"), @@ -188,7 +221,7 @@ def setUpClass(cls): def test_get_subalignment_sequence_order_maintained2(self): result = AlignedSeq.get_sub_alignment_by_list_id(["s1", "s3"], self.alignment) - expected = MultipleSeqAlignment([self.alignment[0], self.alignment[2]]) + expected = MSA([self.alignment[0], self.alignment[2]]) self.assertTrue(msas_equal(expected, result)) def test_get_subalignment_sequence_order_maintained(self): @@ -196,14 +229,14 @@ def test_get_subalignment_sequence_order_maintained(self): Sequences given rearranged are still output in input order """ result = AlignedSeq.get_sub_alignment_by_list_id(["s3", "s1"], self.alignment) - expected = MultipleSeqAlignment([self.alignment[0], self.alignment[2]]) + expected = MSA([self.alignment[0], self.alignment[2]]) self.assertTrue(msas_equal(expected, result)) def test_get_subalignment_with_interval(self): result = AlignedSeq.get_sub_alignment_by_list_id( ["s2", "s3"], self.alignment, [0, 2] ) - expected = MultipleSeqAlignment( + expected = MSA( [SeqRecord(Seq("C--"), id="s2"), SeqRecord(Seq("AAT"), id="s3"),] ) self.assertTrue(msas_equal(expected, result)) From cef88182d151cf9fd0b5557bbd985fb717945037 Mon Sep 17 00:00:00 2001 From: Brice Letcher Date: Tue, 11 Aug 2020 14:15:27 +0100 Subject: [PATCH 3/9] Defer subcommand argument parsing to subcommands dir --- make_prg/__init__.py | 4 ++ make_prg/__main__.py | 75 ++-------------------------- make_prg/make_prg_from_msa.py | 5 +- make_prg/subcommands/prg_from_msa.py | 72 +++++++++++++++++++++++++- 4 files changed, 81 insertions(+), 75 deletions(-) diff --git a/make_prg/__init__.py b/make_prg/__init__.py index d71d2cb..964ce26 100644 --- a/make_prg/__init__.py +++ b/make_prg/__init__.py @@ -1,7 +1,11 @@ +# ___Constants/Aliases___ # from Bio.AlignIO import MultipleSeqAlignment MSA = MultipleSeqAlignment +NESTING_LVL = 5 +MIN_MATCH_LEN = 7 +# ___Version___ # from pkg_resources import get_distribution try: diff --git a/make_prg/__main__.py b/make_prg/__main__.py index ab84914..edc8276 100644 --- a/make_prg/__main__.py +++ b/make_prg/__main__.py @@ -1,9 +1,7 @@ import argparse -import make_prg - -NESTING_LVL = 5 -MIN_MATCH_LEN = 7 +from make_prg import __version__ +from make_prg.subcommands import prg_from_msa def main(): @@ -13,77 +11,12 @@ def main(): description="script to run make_prg subcommands", ) - parser.add_argument("--version", action="version", version=make_prg.__version__) + parser.add_argument("--version", action="version", version=__version__) subparsers = parser.add_subparsers( title="Available subcommands", help="", metavar="" ) - # _____________________________ prg_from_msa ______________________________# - subparser_prg_from_msa = subparsers.add_parser( - "prg_from_msa", - usage="make_prg prg_from_msa [options] ", - help="Make PRG from multiple sequence alignment", - ) - - subparser_prg_from_msa.add_argument( - "MSA", - action="store", - type=str, - help=( - "Input file: a multiple sequence alignment in supported alignment_format. " - "If not in aligned fasta alignment_format, use -f to input the " - "alignment_format type" - ), - ) - subparser_prg_from_msa.add_argument( - "-f", - "--alignment_format", - dest="alignment_format", - action="store", - default="fasta", - help=( - "Alignment format of MSA, must be a biopython AlignIO input " - "alignment_format. See http://biopython.org/wiki/AlignIO. Default: fasta" - ), - ) - subparser_prg_from_msa.add_argument( - "--max_nesting", - dest="max_nesting", - action="store", - type=int, - default=NESTING_LVL, - help="Maximum number of levels to use for nesting. Default: {}".format( - NESTING_LVL - ), - ) - subparser_prg_from_msa.add_argument( - "--min_match_length", - dest="min_match_length", - action="store", - type=int, - default=MIN_MATCH_LEN, - help=( - "Minimum number of consecutive characters which must be identical for a " - "match. Default: {}".format(MIN_MATCH_LEN) - ), - ) - subparser_prg_from_msa.add_argument( - "-p", "--prefix", dest="output_prefix", action="store", help="Output prefix" - ) - subparser_prg_from_msa.add_argument( - "--no_overwrite", - dest="no_overwrite", - action="store_true", - help="Do not overwrite pre-existing prg file with same name", - ) - subparser_prg_from_msa.add_argument( - "-v", - "--verbose", - dest="verbose", - action="store_true", - help="Run with high verbosity " "(debug level logging)", - ) - subparser_prg_from_msa.set_defaults(func=make_prg.subcommands.prg_from_msa.run) + prg_from_msa.register_parser(subparsers) args = parser.parse_args() diff --git a/make_prg/make_prg_from_msa.py b/make_prg/make_prg_from_msa.py index ad097fd..1d9d6f0 100644 --- a/make_prg/make_prg_from_msa.py +++ b/make_prg/make_prg_from_msa.py @@ -6,7 +6,7 @@ from sklearn.cluster import KMeans from make_prg import MSA -from make_prg.io_utils import load_alignment_file, MSA +from make_prg.io_utils import load_alignment_file from make_prg.seq_utils import ( remove_duplicates, remove_gaps, @@ -105,10 +105,9 @@ def interval_partition(self): match_string = remove_gaps( self.consensus[match_start : match_start + match_count] ) - match_len = len(match_string) logging.debug("have match string %s" % match_string) - if match_len >= self.min_match_length: + if len(match_string) >= self.min_match_length: if non_match_start < match_start: non_match_intervals.append([non_match_start, match_start - 1]) logging.debug( diff --git a/make_prg/subcommands/prg_from_msa.py b/make_prg/subcommands/prg_from_msa.py index 06b365c..faa7c32 100644 --- a/make_prg/subcommands/prg_from_msa.py +++ b/make_prg/subcommands/prg_from_msa.py @@ -2,7 +2,77 @@ import os from pathlib import Path -from make_prg import make_prg_from_msa, io_utils +from make_prg import make_prg_from_msa, io_utils, NESTING_LVL, MIN_MATCH_LEN + + +def register_parser(subparsers): + subparser_prg_from_msa = subparsers.add_parser( + "prg_from_msa", + usage="make_prg prg_from_msa [options] ", + help="Make PRG from multiple sequence alignment", + ) + + subparser_prg_from_msa.add_argument( + "MSA", + action="store", + type=str, + help=( + "Input file: a multiple sequence alignment in supported alignment_format. " + "If not in aligned fasta alignment_format, use -f to input the " + "alignment_format type" + ), + ) + subparser_prg_from_msa.add_argument( + "-f", + "--alignment_format", + dest="alignment_format", + action="store", + default="fasta", + help=( + "Alignment format of MSA, must be a biopython AlignIO input " + "alignment_format. See http://biopython.org/wiki/AlignIO. Default: fasta" + ), + ) + subparser_prg_from_msa.add_argument( + "--max_nesting", + dest="max_nesting", + action="store", + type=int, + default=NESTING_LVL, + help="Maximum number of levels to use for nesting. Default: {}".format( + NESTING_LVL + ), + ) + subparser_prg_from_msa.add_argument( + "--min_match_length", + dest="min_match_length", + action="store", + type=int, + default=MIN_MATCH_LEN, + help=( + "Minimum number of consecutive characters which must be identical for a " + "match. Default: {}".format(MIN_MATCH_LEN) + ), + ) + subparser_prg_from_msa.add_argument( + "-p", "--prefix", dest="output_prefix", action="store", help="Output prefix" + ) + subparser_prg_from_msa.add_argument( + "--no_overwrite", + dest="no_overwrite", + action="store_true", + help="Do not overwrite pre-existing prg file with same name", + ) + subparser_prg_from_msa.add_argument( + "-v", + "--verbose", + dest="verbose", + action="store_true", + help="Run with high verbosity " "(debug level logging)", + ) + subparser_prg_from_msa.set_defaults(func=run) + + return subparser_prg_from_msa def run(options): From b56f18f214444bcab6683b2814e0c62750dda19d Mon Sep 17 00:00:00 2001 From: Brice Letcher Date: Tue, 11 Aug 2020 15:36:13 +0100 Subject: [PATCH 4/9] Refactor interval partitioning * Move functions to test interval consistency out of class * Unit test them --- make_prg/__init__.py | 2 +- make_prg/make_prg_from_msa.py | 65 +++++------------ make_prg/seq_utils.py | 70 ------------------ make_prg/utils.py | 125 ++++++++++++++++++++++++++++++++ tests/test_make_prg_from_msa.py | 69 ++++++++++++++++-- tests/test_seq_utils.py | 2 +- 6 files changed, 209 insertions(+), 124 deletions(-) delete mode 100644 make_prg/seq_utils.py create mode 100644 make_prg/utils.py diff --git a/make_prg/__init__.py b/make_prg/__init__.py index 964ce26..a3841e5 100644 --- a/make_prg/__init__.py +++ b/make_prg/__init__.py @@ -13,6 +13,6 @@ except: __version__ = "local" -__all__ = ["make_prg_from_msa", "subcommands", "io_utils", "seq_utils"] +__all__ = ["make_prg_from_msa", "subcommands", "io_utils", "utils"] from make_prg import * diff --git a/make_prg/make_prg_from_msa.py b/make_prg/make_prg_from_msa.py index 1d9d6f0..4c56f1c 100644 --- a/make_prg/make_prg_from_msa.py +++ b/make_prg/make_prg_from_msa.py @@ -7,11 +7,13 @@ from make_prg import MSA from make_prg.io_utils import load_alignment_file -from make_prg.seq_utils import ( +from make_prg.utils import ( + ambiguous_bases, remove_duplicates, remove_gaps, get_interval_seqs, - ambiguous_bases, + enforce_multisequence_nonmatch_intervals, + enforce_alignment_interval_bijection, ) @@ -46,11 +48,11 @@ def __init__( self.interval = interval self.consensus = self.get_consensus(self.alignment) self.length = len(self.consensus) - (self.match_intervals, self.non_match_intervals) = self.interval_partition() - self.check_nonmatch_intervals() - self.all_intervals = self.match_intervals + self.non_match_intervals - logging.info("Non match intervals: %s", self.non_match_intervals) - self.all_intervals.sort() + ( + self.match_intervals, + self.non_match_intervals, + self.all_intervals, + ) = self.partition_alignment_into_intervals() # properties for stats self.subAlignedSeqs = {} @@ -85,7 +87,7 @@ def get_consensus(cls, alignment: MSA): return consensus_string - def interval_partition(self): + def partition_alignment_into_intervals(self): """Return a list of intervals in which we have consensus sequence longer than min_match_length, and a list of the non-match intervals left.""" @@ -143,44 +145,17 @@ def interval_partition(self): non_match_intervals.append([non_match_start, end]) logging.debug(f"add non-match interval [{non_match_start},{end}]") - # check all stretches of consensus are in an interval, and intervals don't overlap - for i in range(self.length): - count_match = 0 - for interval in match_intervals: - if interval[0] <= i <= interval[1]: - count_match += 1 - count_non_match = 0 - for interval in non_match_intervals: - if interval[0] <= i <= interval[1]: - count_non_match += 1 - - assert count_match | count_non_match, ( - "Failed to correctly identify match intervals: position %d " - "appeared in both/neither match and non-match intervals" % i - ) - assert count_match + count_non_match == 1, ( - "Failed to correctly identify match intervals: position " - "%d appeared in %d intervals" % (i, count_match + count_non_match) - ) + enforce_alignment_interval_bijection( + match_intervals, non_match_intervals, self.length + ) - return match_intervals, non_match_intervals - - def check_nonmatch_intervals(self): - """ - Goes through non-match intervals and makes sure there is more than one sequence there, else makes it a match - interval. - Example reasons for such a conversion to occur: - - 'N' in a sequence causes it to be filtered out, and left with a single useable sequence - - '-' in sequences causes them to appear different, but they are the same - """ - for i in reversed(range(len(self.non_match_intervals))): - interval = self.non_match_intervals[i] - interval_alignment = self.alignment[:, interval[0] : interval[1] + 1] - interval_seqs = get_interval_seqs(interval_alignment) - if len(interval_seqs) < 2: - self.match_intervals.append(self.non_match_intervals[i]) - self.non_match_intervals.pop(i) - self.match_intervals.sort() + logging.info("Non match intervals: %s", non_match_intervals) + enforce_multisequence_nonmatch_intervals( + match_intervals, non_match_intervals, self.alignment + ) + all_intervals = match_intervals + non_match_intervals + all_intervals.sort() + return match_intervals, non_match_intervals, all_intervals @classmethod def kmeans_cluster_seqs_in_interval( diff --git a/make_prg/seq_utils.py b/make_prg/seq_utils.py deleted file mode 100644 index 1b700f1..0000000 --- a/make_prg/seq_utils.py +++ /dev/null @@ -1,70 +0,0 @@ -import logging -from typing import Generator, Sequence -import itertools - -from Bio import AlignIO - - -def remove_duplicates(seqs: Sequence) -> Generator: - seen = set() - for x in seqs: - if x in seen: - continue - seen.add(x) - yield x - - -def remove_gaps(sequence: str) -> str: - return sequence.replace("-", "") - - -iupac = { - "R": "GA", - "Y": "TC", - "K": "GT", - "M": "AC", - "S": "GC", - "W": "AT", - "A": "A", - "C": "C", - "G": "G", - "T": "T", -} -allowed_bases = set(iupac.keys()) -standard_bases = {"A", "C", "G", "T"} -ambiguous_bases = allowed_bases.difference(standard_bases) - - -def get_interval_seqs(interval_alignment: AlignIO.MultipleSeqAlignment): - """ - Replace - with nothing, remove seqs containing N or other non-allowed letters - and duplicate sequences containing RYKMSW, replacing with AGCT alternatives - - The sequences are deliberately returned in the order they are received - """ - gapless_seqs = [str(record.seq.ungap("-")) for record in interval_alignment] - - callback_seqs, expanded_seqs = [], [] - expanded_set = set() - for seq in remove_duplicates(gapless_seqs): - if len(expanded_set) == 0: - callback_seqs.append(seq) - if not set(seq).issubset(allowed_bases): - continue - alternatives = [iupac[base] for base in seq] - for tuple_product in itertools.product(*alternatives): - expanded_str = "".join(tuple_product) - if expanded_str not in expanded_set: - expanded_set.add(expanded_str) - expanded_seqs.append(expanded_str) - - if len(expanded_set) == 0: - logging.warning( - "WARNING: Every sequence must have contained an N in this slice - redo sequence curation because this is nonsense" - ) - logging.warning(f'Sequences were: {" ".join(callback_seqs)}') - logging.warning( - "Using these sequences anyway, and should be ignored downstream" - ) - return callback_seqs - return expanded_seqs diff --git a/make_prg/utils.py b/make_prg/utils.py new file mode 100644 index 0000000..da9342a --- /dev/null +++ b/make_prg/utils.py @@ -0,0 +1,125 @@ +import logging +from typing import Generator, Sequence +import itertools + +from Bio import AlignIO + +from make_prg import MSA + + +class PartitioningError(Exception): + pass + + +def remove_duplicates(seqs: Sequence) -> Generator: + seen = set() + for x in seqs: + if x in seen: + continue + seen.add(x) + yield x + + +def remove_gaps(sequence: str) -> str: + return sequence.replace("-", "") + + +iupac = { + "R": "GA", + "Y": "TC", + "K": "GT", + "M": "AC", + "S": "GC", + "W": "AT", + "A": "A", + "C": "C", + "G": "G", + "T": "T", +} +allowed_bases = set(iupac.keys()) +standard_bases = {"A", "C", "G", "T"} +ambiguous_bases = allowed_bases.difference(standard_bases) + + +def get_interval_seqs(interval_alignment: AlignIO.MultipleSeqAlignment): + """ + Replace - with nothing, remove seqs containing N or other non-allowed letters + and duplicate sequences containing RYKMSW, replacing with AGCT alternatives + + The sequences are deliberately returned in the order they are received + """ + gapless_seqs = [str(record.seq.ungap("-")) for record in interval_alignment] + + callback_seqs, expanded_seqs = [], [] + expanded_set = set() + for seq in remove_duplicates(gapless_seqs): + if len(expanded_set) == 0: + callback_seqs.append(seq) + if not set(seq).issubset(allowed_bases): + continue + alternatives = [iupac[base] for base in seq] + for tuple_product in itertools.product(*alternatives): + expanded_str = "".join(tuple_product) + if expanded_str not in expanded_set: + expanded_set.add(expanded_str) + expanded_seqs.append(expanded_str) + + if len(expanded_set) == 0: + logging.warning( + "WARNING: Every sequence must have contained an N in this slice - redo sequence curation because this is nonsense" + ) + logging.warning(f'Sequences were: {" ".join(callback_seqs)}') + logging.warning( + "Using these sequences anyway, and should be ignored downstream" + ) + return callback_seqs + return expanded_seqs + + +def enforce_multisequence_nonmatch_intervals( + match_intervals, non_match_intervals, alignment: MSA +): + """ + Goes through non-match intervals and makes sure there is more than one sequence there, else makes it a match + interval. + Example reasons for such a conversion to occur: + - 'N' in a sequence causes it to be filtered out, and left with a single useable sequence + - '-' in sequences causes them to appear different, but they are the same + """ + for i in reversed(range(len(non_match_intervals))): + interval = non_match_intervals[i] + interval_alignment = alignment[:, interval[0] : interval[1] + 1] + interval_seqs = get_interval_seqs(interval_alignment) + if len(interval_seqs) < 2: + match_intervals.append(non_match_intervals[i]) + non_match_intervals.pop(i) + match_intervals.sort() + + +def enforce_alignment_interval_bijection( + match_intervals, non_match_intervals, alignment_length: int +): + """ + Check each position in an alignment is in one, and one only, (match or non_match) interval + """ + for i in range(alignment_length): + count_match = 0 + for interval in match_intervals: + if interval[0] <= i <= interval[1]: + count_match += 1 + count_non_match = 0 + for interval in non_match_intervals: + if interval[0] <= i <= interval[1]: + count_non_match += 1 + + if count_match > 1 or count_non_match > 1: + raise PartitioningError( + f"Failed interval partitioning: position {i}" + " appears in more than one interval" + ) + if not count_match ^ count_non_match: # test fails if they are the same integer + msg = ["neither", "nor"] if count_match == 0 else ["both", "and"] + raise PartitioningError( + "Failed interval partitioning: alignment position %d" + "classified as %s match %s non-match " % (i, *msg) + ) diff --git a/tests/test_make_prg_from_msa.py b/tests/test_make_prg_from_msa.py index fcb6cd7..0b81c4d 100644 --- a/tests/test_make_prg_from_msa.py +++ b/tests/test_make_prg_from_msa.py @@ -8,7 +8,12 @@ from make_prg import MSA from make_prg.make_prg_from_msa import AlignedSeq -from make_prg.seq_utils import standard_bases +from make_prg.utils import ( + standard_bases, + enforce_multisequence_nonmatch_intervals, + enforce_alignment_interval_bijection, + PartitioningError, +) this_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) data_dir = os.path.join(this_dir, "tests", "data", "make_prg_from_msa") @@ -33,6 +38,11 @@ def test_mixed_match_nonmatch(self): result = AlignedSeq.get_consensus(alignment) self.assertEqual(result, "*A*TA") + def test_indel_nonmatch(self): + alignment = make_alignment(["AAAA", "A--A"]) + result = AlignedSeq.get_consensus(alignment) + self.assertEqual(result, "A**A") + def test_IUPACAmbiguous_nonmatch(self): alignment = make_alignment(["RYA", "RTA"]) result = AlignedSeq.get_consensus(alignment) @@ -47,42 +57,87 @@ def test_N_special_treatment(self): self.assertEqual(result, "*TA") -@mock.patch.object(AlignedSeq, "check_nonmatch_intervals") +class TestIntervalConsistency(TestCase): + def test_nonmatch_interval_switching_indels(self): + """Because the sequences are the same, despite different alignment""" + alignment = make_alignment(["A---A", "A-A--"]) + match_intervals = [] + non_match_intervals = [[0, 5]] + enforce_multisequence_nonmatch_intervals( + match_intervals, non_match_intervals, alignment + ) + self.assertEqual(match_intervals, [[0, 5]]) + self.assertEqual(non_match_intervals, []) + + def test_nonmatch_interval_switching_Ns(self): + """'N's make sequences get removed""" + alignment = make_alignment(["ANAAA", "ATAAT"]) + match_intervals = [] + non_match_intervals = [[0, 5]] + enforce_multisequence_nonmatch_intervals( + match_intervals, non_match_intervals, alignment + ) + self.assertEqual(match_intervals, [[0, 5]]) + self.assertEqual(non_match_intervals, []) + + def test_position_in_several_intervals_fails(self): + match_intervals = [[0, 1], [1, 2]] + with self.assertRaises(PartitioningError): + enforce_alignment_interval_bijection(match_intervals, [], 3) + + def test_position_in_no_interval_fails(self): + match_intervals = [[0, 1]] + with self.assertRaises(PartitioningError): + enforce_alignment_interval_bijection(match_intervals, [], 3) + + def test_position_in_match_and_nonmatch_intervals_fails(self): + match_intervals = [[0, 2]] + nmatch_intervals = [[2, 3]] + with self.assertRaises(PartitioningError): + enforce_alignment_interval_bijection(match_intervals, nmatch_intervals, 4) + + def test_bijection_respected_passes(self): + match_intervals = [[0, 2], [5, 10]] + nmatch_intervals = [[3, 4]] + enforce_alignment_interval_bijection(match_intervals, nmatch_intervals, 11) + + +@mock.patch("make_prg.make_prg_from_msa.enforce_multisequence_nonmatch_intervals") @mock.patch.object(AlignedSeq, "get_prg") @mock.patch.object(AlignedSeq, "get_consensus") class TestIntervalPartitioning(TestCase): def test_all_non_match(self, get_consensus, _, __): get_consensus.return_value = "******" tester = AlignedSeq("_", alignment="_", min_match_length=3) - match, non_match = tester.interval_partition() + match, non_match, _ = tester.partition_alignment_into_intervals() self.assertEqual(match, []) self.assertEqual(non_match, [[0, 5]]) def test_all_match(self, get_consensus, _, __): get_consensus.return_value = "ATATAAA" tester = AlignedSeq("_", alignment="_", min_match_length=3) - match, non_match = tester.interval_partition() + match, non_match, _ = tester.partition_alignment_into_intervals() self.assertEqual(match, [[0, 6]]) self.assertEqual(non_match, []) def test_short_match_counted_as_non_match(self, get_consensus, _, __): get_consensus.return_value = "AT***" tester = AlignedSeq("_", alignment="_", min_match_length=3) - match, non_match = tester.interval_partition() + match, non_match, _ = tester.partition_alignment_into_intervals() self.assertEqual(match, []) self.assertEqual(non_match, [[0, 4]]) def test_match_non_match_match(self, get_consensus, _, __): get_consensus.return_value = "ATT**AAAC" tester = AlignedSeq("_", alignment="_", min_match_length=3) - match, non_match = tester.interval_partition() + match, non_match, _ = tester.partition_alignment_into_intervals() self.assertEqual(match, [[0, 2], [5, 8]]) self.assertEqual(non_match, [[3, 4]]) def test_end_in_non_match(self, get_consensus, _, __): get_consensus.return_value = "**ATT**AAA*C" tester = AlignedSeq("_", alignment="_", min_match_length=3) - match, non_match = tester.interval_partition() + match, non_match, _ = tester.partition_alignment_into_intervals() self.assertEqual(match, [[2, 4], [7, 9]]) self.assertEqual(non_match, [[0, 1], [5, 6], [10, 11]]) diff --git a/tests/test_seq_utils.py b/tests/test_seq_utils.py index 9c819a3..0f9e2c7 100644 --- a/tests/test_seq_utils.py +++ b/tests/test_seq_utils.py @@ -6,7 +6,7 @@ from Bio.Seq import Seq from Bio.SeqRecord import SeqRecord -from make_prg.seq_utils import remove_gaps, get_interval_seqs +from make_prg.utils import remove_gaps, get_interval_seqs class TestGetIntervals(unittest.TestCase): From 771722b33a10da16d5a0ec79068b0d509344c8e1 Mon Sep 17 00:00:00 2001 From: Brice Letcher Date: Tue, 11 Aug 2020 18:38:24 +0100 Subject: [PATCH 5/9] Alignment column of all '-' no longer a match --- make_prg/make_prg_from_msa.py | 55 +++++++++++++-------------------- make_prg/utils.py | 4 --- tests/test_make_prg_from_msa.py | 5 +++ tests/test_seq_utils.py | 44 +------------------------- 4 files changed, 28 insertions(+), 80 deletions(-) diff --git a/make_prg/make_prg_from_msa.py b/make_prg/make_prg_from_msa.py index 4c56f1c..8bbe79b 100644 --- a/make_prg/make_prg_from_msa.py +++ b/make_prg/make_prg_from_msa.py @@ -10,7 +10,6 @@ from make_prg.utils import ( ambiguous_bases, remove_duplicates, - remove_gaps, get_interval_seqs, enforce_multisequence_nonmatch_intervals, enforce_alignment_interval_bijection, @@ -78,9 +77,12 @@ def get_consensus(cls, alignment: MSA): consensus_string = "" for i in range(alignment.get_alignment_length()): column = set([record.seq[i] for record in alignment]) - if "N" in column: - column.remove("N") - if len(ambiguous_bases.intersection(column)) > 0 or len(column) != 1: + column = column.difference({"N"}) + if ( + len(ambiguous_bases.intersection(column)) > 0 + or len(column) != 1 + or column == {"-"} + ): consensus_string += "*" else: consensus_string += column.pop() @@ -96,20 +98,20 @@ def partition_alignment_into_intervals(self): logging.debug("consensus: %s" % self.consensus) for i in range(self.length): - letter = self.consensus[i] - if letter != "*": + if self.consensus[i] != "*": # In a match region. if match_count == 0: match_start = i match_count += 1 - elif match_count > 0: - # Have reached a non-match. Check if previous match string is long enough to add to match_regions - match_string = remove_gaps( - self.consensus[match_start : match_start + match_count] + else: + if match_count == 0: + continue + logging.debug( + "have match string %s" + % self.consensus[match_start : match_start + match_count] ) - logging.debug("have match string %s" % match_string) - if len(match_string) >= self.min_match_length: + if (match_count - match_start + 1) >= self.min_match_length: if non_match_start < match_start: non_match_intervals.append([non_match_start, match_start - 1]) logging.debug( @@ -118,9 +120,8 @@ def partition_alignment_into_intervals(self): end = match_start + match_count - 1 match_intervals.append([match_start, end]) logging.debug(f"add match interval [{match_start},{end}]") - non_match_start = i + non_match_start = match_start = i match_count = 0 - match_start = non_match_start end = self.length - 1 if self.length < self.min_match_length: @@ -318,24 +319,21 @@ def get_prg(self): if (self.nesting_level == self.max_nesting) or ( interval[1] - interval[0] <= self.min_match_length ): - # Have reached max nesting level, just add all variants in interval. logging.debug( "Have reached max nesting level or have a small variant site, so add all variant " "sequences in interval." ) sub_alignment = self.alignment[:, interval[0] : interval[1] + 1] - logging.debug( - "Variant seqs found: %s" - % list( + if logging.getLogger().isEnabledFor(logging.DEBUG): + seqs = list( remove_duplicates( [str(record.seq) for record in sub_alignment] ) ) - ) + logging.debug(f"Variant seqs found: {seqs}") variant_prgs = get_interval_seqs(sub_alignment) logging.debug("Which is equivalent to: %s" % variant_prgs) else: - # divide sequences into subgroups and define prg for each subgroup. logging.debug( "Divide sequences into subgroups and define prg for each subgroup." ) @@ -385,11 +383,7 @@ def get_prg(self): self.site = sub_aligned_seq.site if recur: - # logging.debug("None not in snp_scores - try to add sub__aligned_seq to list in - # dictionary") self.subAlignedSeqs[interval[0]].append(sub_aligned_seq) - # logging.debug("Length of subAlignedSeqs[%d] is %d", interval[0], - # len(self.subAlignedSeqs[interval[0]])) assert num_clusters == len(variant_prgs), ( "I don't seem to have a sub-prg sequence for all parts of the partition - there are %d " "classes in partition, and %d variant seqs" @@ -401,18 +395,13 @@ def get_prg(self): list(remove_duplicates(variant_prgs)) ), "have repeat variant seqs" - # Add the variant seqs to the prg - prg += "%s%d%s" % ( - self.delim_char, - site_num, - self.delim_char, - ) # considered making it so start of prg was not delim_char, - # but that would defeat the point if it + # Add the variant seqs to the prg. + prg += f"{self.delim_char}{site_num}{self.delim_char}" while len(variant_prgs) > 1: prg += variant_prgs.pop(0) - prg += "%s%d%s" % (self.delim_char, site_num + 1, self.delim_char) + prg += f"{self.delim_char}{site_num + 1}{self.delim_char}" prg += variant_prgs.pop() - prg += "%s%d%s" % (self.delim_char, site_num, self.delim_char) + prg += f"{self.delim_char}{site_num}{self.delim_char}" return prg diff --git a/make_prg/utils.py b/make_prg/utils.py index da9342a..b1f9141 100644 --- a/make_prg/utils.py +++ b/make_prg/utils.py @@ -20,10 +20,6 @@ def remove_duplicates(seqs: Sequence) -> Generator: yield x -def remove_gaps(sequence: str) -> str: - return sequence.replace("-", "") - - iupac = { "R": "GA", "Y": "TC", diff --git a/tests/test_make_prg_from_msa.py b/tests/test_make_prg_from_msa.py index 0b81c4d..7d150c9 100644 --- a/tests/test_make_prg_from_msa.py +++ b/tests/test_make_prg_from_msa.py @@ -56,6 +56,11 @@ def test_N_special_treatment(self): result = AlignedSeq.get_consensus(alignment) self.assertEqual(result, "*TA") + def test_all_gap_nonmatch(self): + alignment = make_alignment(["A--A", "A--A"]) + result = AlignedSeq.get_consensus(alignment) + self.assertEqual(result, "A**A") + class TestIntervalConsistency(TestCase): def test_nonmatch_interval_switching_indels(self): diff --git a/tests/test_seq_utils.py b/tests/test_seq_utils.py index 0f9e2c7..3282952 100644 --- a/tests/test_seq_utils.py +++ b/tests/test_seq_utils.py @@ -1,12 +1,10 @@ import unittest -from hypothesis import given -from hypothesis.strategies import text from Bio import AlignIO from Bio.Seq import Seq from Bio.SeqRecord import SeqRecord -from make_prg.utils import remove_gaps, get_interval_seqs +from make_prg.utils import get_interval_seqs class TestGetIntervals(unittest.TestCase): @@ -29,43 +27,3 @@ def test_first_sequence_in_is_first_sequence_out(self): result = get_interval_seqs(alignment) expected = ["TTTT", "AAAA", "CCC"] self.assertEqual(expected, result) - - -class TestRemoveGaps(unittest.TestCase): - def test_empty_string_returns_empty(self): - seq = "" - - actual = remove_gaps(seq) - expected = "" - - self.assertEqual(actual, expected) - - def test_string_with_no_gaps_returns_original(self): - seq = "ACGT" - - actual = remove_gaps(seq) - expected = seq - - self.assertEqual(actual, expected) - - def test_string_with_one_gaps_returns_original_without_gap(self): - seq = "ACGT-" - - actual = remove_gaps(seq) - expected = "ACGT" - - self.assertEqual(actual, expected) - - def test_string_with_many_gaps_returns_original_without_gap(self): - seq = "A-CGT--" - - actual = remove_gaps(seq) - expected = "ACGT" - - self.assertEqual(actual, expected) - - @given(text()) - def test_all_input_space_doesnt_break(self, seq): - actual = remove_gaps(seq) - - self.assertFalse("-" in actual) From 8cd3fc54420e2e7c5638bdd7bb32d7af1ba244fe Mon Sep 17 00:00:00 2001 From: Brice Letcher Date: Wed, 12 Aug 2020 10:40:50 +0100 Subject: [PATCH 6/9] Refactored interval partitioning (1) * New object-oriented interval partitioning code * Add unit test for small consensus string which led to a bugfix --- make_prg/__init__.py | 2 +- make_prg/interval_partition.py | 161 ++++++++++++++++++++++++++++++++ make_prg/make_prg_from_msa.py | 21 ++--- make_prg/seq_utils.py | 66 +++++++++++++ make_prg/utils.py | 121 ------------------------ tests/test_make_prg_from_msa.py | 25 +++-- tests/test_seq_utils.py | 2 +- 7 files changed, 254 insertions(+), 144 deletions(-) create mode 100644 make_prg/interval_partition.py create mode 100644 make_prg/seq_utils.py delete mode 100644 make_prg/utils.py diff --git a/make_prg/__init__.py b/make_prg/__init__.py index a3841e5..964ce26 100644 --- a/make_prg/__init__.py +++ b/make_prg/__init__.py @@ -13,6 +13,6 @@ except: __version__ = "local" -__all__ = ["make_prg_from_msa", "subcommands", "io_utils", "utils"] +__all__ = ["make_prg_from_msa", "subcommands", "io_utils", "seq_utils"] from make_prg import * diff --git a/make_prg/interval_partition.py b/make_prg/interval_partition.py new file mode 100644 index 0000000..6e1cd52 --- /dev/null +++ b/make_prg/interval_partition.py @@ -0,0 +1,161 @@ +""" +Code responsible for converting a consensus string into a set of disjoint +match/non_match intervals. +""" +from enum import Enum, auto +from typing import List, Tuple + +from make_prg import MSA + +from make_prg.seq_utils import get_interval_seqs + + +class IntervalType(Enum): + Match = auto() + NonMatch = auto() + + +def get_type(letter: str) -> IntervalType: + if letter == "*": + return IntervalType.NonMatch + else: + return IntervalType.Match + + +def is_type(letter: str, interval_type: IntervalType) -> bool: + if get_type(letter) is interval_type: + return True + else: + return False + + +class Interval: + def __init__(self, it_type: IntervalType, start: int, stop: int = None): + self.type = it_type + self.start = start + self.stop = stop if stop is not None else start + + def modify_by(self, left_delta: int, right_delta: int): + self.start += left_delta + self.stop += right_delta + + def __len__(self) -> int: + return self.stop - self.start + 1 + + def __lt__(self, other: "Interval") -> bool: + return self.start < other.start + + +Intervals = List[Interval] + + +class IntervalPartitioner: + """Produces a list of intervals in which we have + consensus sequence longer than min_match_length, and + a list of the non-match intervals left.""" + + def __init__(self, consensus_string: str, min_match_length: int, alignment: MSA): + self._match_intervals: Intervals = list() + self._non_match_intervals: Intervals = list() + self.mml = min_match_length + + cur_interval = self._new_interval(consensus_string[0], 0) + + for i, letter in enumerate(consensus_string[1:], start=1): + if is_type(letter, cur_interval.type): + cur_interval.modify_by(0, 1) # simple interval extension + else: + self._add_interval(cur_interval, alignment) + cur_interval = self._new_interval(letter, i) + self._add_interval(cur_interval, alignment) + + def get_intervals(self) -> Tuple[Intervals, Intervals, Intervals]: + return ( + sorted(self._match_intervals), + sorted(self._non_match_intervals), + sorted(self._match_intervals + self._non_match_intervals), + ) + + def _new_interval(self, letter: str, start_pos: int) -> Interval: + return Interval(get_type(letter), start_pos) + + def _append(self, interval: Interval): + if interval.type is IntervalType.Match: + self._match_intervals.append(interval) + else: + self._non_match_intervals.append(interval) + + def _pop(self, it_type: IntervalType) -> Interval: + if it_type is IntervalType.Match: + return self._match_intervals.pop() + else: + return self._non_match_intervals.pop() + + def _add_interval(self, interval: Interval, alignment: MSA): + if interval.type is IntervalType.Match: + if len(interval) < self.mml: + try: + last_non_match = self._pop(IntervalType.NonMatch) + last_non_match.modify_by(0, len(interval)) + except IndexError: + last_non_match = Interval( + interval.start, interval.stop, IntervalType.NonMatch + ) + self.append(last_non_match) + return + else: + pass + self._append(interval) + + +def enforce_multisequence_nonmatch_intervals( + match_intervals, non_match_intervals, alignment: MSA +): + """ + Goes through non-match intervals and makes sure there is more than one sequence there, else makes it a match + interval. + Example reasons for such a conversion to occur: + - 'N' in a sequence causes it to be filtered out, and left with a single useable sequence + - '-' in sequences causes them to appear different, but they are the same + """ + for i in reversed(range(len(non_match_intervals))): + interval = non_match_intervals[i] + interval_alignment = alignment[:, interval[0] : interval[1] + 1] + interval_seqs = get_interval_seqs(interval_alignment) + if len(interval_seqs) < 2: + match_intervals.append(non_match_intervals[i]) + non_match_intervals.pop(i) + match_intervals.sort() + + +class PartitioningError(Exception): + pass + + +def enforce_alignment_interval_bijection( + match_intervals, non_match_intervals, alignment_length: int +): + """ + Check each position in an alignment is in one, and one only, (match or non_match) interval + """ + for i in range(alignment_length): + count_match = 0 + for interval in match_intervals: + if interval[0] <= i <= interval[1]: + count_match += 1 + count_non_match = 0 + for interval in non_match_intervals: + if interval[0] <= i <= interval[1]: + count_non_match += 1 + + if count_match > 1 or count_non_match > 1: + raise PartitioningError( + f"Failed interval partitioning: position {i}" + " appears in more than one interval" + ) + if not count_match ^ count_non_match: # test fails if they are the same integer + msg = ["neither", "nor"] if count_match == 0 else ["both", "and"] + raise PartitioningError( + "Failed interval partitioning: alignment position %d" + "classified as %s match %s non-match " % (i, *msg) + ) diff --git a/make_prg/make_prg_from_msa.py b/make_prg/make_prg_from_msa.py index 8bbe79b..648662d 100644 --- a/make_prg/make_prg_from_msa.py +++ b/make_prg/make_prg_from_msa.py @@ -7,10 +7,12 @@ from make_prg import MSA from make_prg.io_utils import load_alignment_file -from make_prg.utils import ( +from make_prg.seq_utils import ( ambiguous_bases, remove_duplicates, get_interval_seqs, +) +from make_prg.interval_partition import ( enforce_multisequence_nonmatch_intervals, enforce_alignment_interval_bijection, ) @@ -111,7 +113,7 @@ def partition_alignment_into_intervals(self): % self.consensus[match_start : match_start + match_count] ) - if (match_count - match_start + 1) >= self.min_match_length: + if match_count >= self.min_match_length: if non_match_start < match_start: non_match_intervals.append([non_match_start, match_start - 1]) logging.debug( @@ -126,7 +128,7 @@ def partition_alignment_into_intervals(self): end = self.length - 1 if self.length < self.min_match_length: # Special case: a short sequence can still get classified as a match interval - added_interval = "match" if "*" in self.consensus else "non_match" + added_interval = "non_match" if "*" in self.consensus else "match" if added_interval == "match": match_intervals.append([0, end]) else: @@ -296,8 +298,6 @@ def get_sub_alignment_by_list_id( def get_prg(self): prg = "" - # last_char = None - # skip_char = False for interval in self.all_intervals: if interval in self.match_intervals: @@ -305,7 +305,7 @@ def get_prg(self): # thus still process all of them, to get the one with no 'N'. sub_alignment = self.alignment[:, interval[0] : interval[1] + 1] seqs = get_interval_seqs(sub_alignment) - assert 0 < len(seqs) <= 1, "Got >1 filtered sequences in match interval" + assert len(seqs) == 1, "Got >1 filtered sequences in match interval" seq = seqs[0] prg += seq @@ -324,15 +324,8 @@ def get_prg(self): "sequences in interval." ) sub_alignment = self.alignment[:, interval[0] : interval[1] + 1] - if logging.getLogger().isEnabledFor(logging.DEBUG): - seqs = list( - remove_duplicates( - [str(record.seq) for record in sub_alignment] - ) - ) - logging.debug(f"Variant seqs found: {seqs}") variant_prgs = get_interval_seqs(sub_alignment) - logging.debug("Which is equivalent to: %s" % variant_prgs) + logging.debug(f"Variant seqs found: {variant_prgs}") else: logging.debug( "Divide sequences into subgroups and define prg for each subgroup." diff --git a/make_prg/seq_utils.py b/make_prg/seq_utils.py new file mode 100644 index 0000000..b08b2bf --- /dev/null +++ b/make_prg/seq_utils.py @@ -0,0 +1,66 @@ +import logging +from typing import Generator, Sequence +import itertools + +from Bio import AlignIO + + +def remove_duplicates(seqs: Sequence) -> Generator: + seen = set() + for x in seqs: + if x in seen: + continue + seen.add(x) + yield x + + +iupac = { + "R": "GA", + "Y": "TC", + "K": "GT", + "M": "AC", + "S": "GC", + "W": "AT", + "A": "A", + "C": "C", + "G": "G", + "T": "T", +} +allowed_bases = set(iupac.keys()) +standard_bases = {"A", "C", "G", "T"} +ambiguous_bases = allowed_bases.difference(standard_bases) + + +def get_interval_seqs(interval_alignment: AlignIO.MultipleSeqAlignment): + """ + Replace - with nothing, remove seqs containing N or other non-allowed letters + and duplicate sequences containing RYKMSW, replacing with AGCT alternatives + + The sequences are deliberately returned in the order they are received + """ + gapless_seqs = [str(record.seq.ungap("-")) for record in interval_alignment] + + callback_seqs, expanded_seqs = [], [] + expanded_set = set() + for seq in remove_duplicates(gapless_seqs): + if len(expanded_set) == 0: + callback_seqs.append(seq) + if not set(seq).issubset(allowed_bases): + continue + alternatives = [iupac[base] for base in seq] + for tuple_product in itertools.product(*alternatives): + expanded_str = "".join(tuple_product) + if expanded_str not in expanded_set: + expanded_set.add(expanded_str) + expanded_seqs.append(expanded_str) + + if len(expanded_set) == 0: + logging.warning( + "WARNING: Every sequence must have contained an N in this slice - redo sequence curation because this is nonsense" + ) + logging.warning(f'Sequences were: {" ".join(callback_seqs)}') + logging.warning( + "Using these sequences anyway, and should be ignored downstream" + ) + return callback_seqs + return expanded_seqs diff --git a/make_prg/utils.py b/make_prg/utils.py deleted file mode 100644 index b1f9141..0000000 --- a/make_prg/utils.py +++ /dev/null @@ -1,121 +0,0 @@ -import logging -from typing import Generator, Sequence -import itertools - -from Bio import AlignIO - -from make_prg import MSA - - -class PartitioningError(Exception): - pass - - -def remove_duplicates(seqs: Sequence) -> Generator: - seen = set() - for x in seqs: - if x in seen: - continue - seen.add(x) - yield x - - -iupac = { - "R": "GA", - "Y": "TC", - "K": "GT", - "M": "AC", - "S": "GC", - "W": "AT", - "A": "A", - "C": "C", - "G": "G", - "T": "T", -} -allowed_bases = set(iupac.keys()) -standard_bases = {"A", "C", "G", "T"} -ambiguous_bases = allowed_bases.difference(standard_bases) - - -def get_interval_seqs(interval_alignment: AlignIO.MultipleSeqAlignment): - """ - Replace - with nothing, remove seqs containing N or other non-allowed letters - and duplicate sequences containing RYKMSW, replacing with AGCT alternatives - - The sequences are deliberately returned in the order they are received - """ - gapless_seqs = [str(record.seq.ungap("-")) for record in interval_alignment] - - callback_seqs, expanded_seqs = [], [] - expanded_set = set() - for seq in remove_duplicates(gapless_seqs): - if len(expanded_set) == 0: - callback_seqs.append(seq) - if not set(seq).issubset(allowed_bases): - continue - alternatives = [iupac[base] for base in seq] - for tuple_product in itertools.product(*alternatives): - expanded_str = "".join(tuple_product) - if expanded_str not in expanded_set: - expanded_set.add(expanded_str) - expanded_seqs.append(expanded_str) - - if len(expanded_set) == 0: - logging.warning( - "WARNING: Every sequence must have contained an N in this slice - redo sequence curation because this is nonsense" - ) - logging.warning(f'Sequences were: {" ".join(callback_seqs)}') - logging.warning( - "Using these sequences anyway, and should be ignored downstream" - ) - return callback_seqs - return expanded_seqs - - -def enforce_multisequence_nonmatch_intervals( - match_intervals, non_match_intervals, alignment: MSA -): - """ - Goes through non-match intervals and makes sure there is more than one sequence there, else makes it a match - interval. - Example reasons for such a conversion to occur: - - 'N' in a sequence causes it to be filtered out, and left with a single useable sequence - - '-' in sequences causes them to appear different, but they are the same - """ - for i in reversed(range(len(non_match_intervals))): - interval = non_match_intervals[i] - interval_alignment = alignment[:, interval[0] : interval[1] + 1] - interval_seqs = get_interval_seqs(interval_alignment) - if len(interval_seqs) < 2: - match_intervals.append(non_match_intervals[i]) - non_match_intervals.pop(i) - match_intervals.sort() - - -def enforce_alignment_interval_bijection( - match_intervals, non_match_intervals, alignment_length: int -): - """ - Check each position in an alignment is in one, and one only, (match or non_match) interval - """ - for i in range(alignment_length): - count_match = 0 - for interval in match_intervals: - if interval[0] <= i <= interval[1]: - count_match += 1 - count_non_match = 0 - for interval in non_match_intervals: - if interval[0] <= i <= interval[1]: - count_non_match += 1 - - if count_match > 1 or count_non_match > 1: - raise PartitioningError( - f"Failed interval partitioning: position {i}" - " appears in more than one interval" - ) - if not count_match ^ count_non_match: # test fails if they are the same integer - msg = ["neither", "nor"] if count_match == 0 else ["both", "and"] - raise PartitioningError( - "Failed interval partitioning: alignment position %d" - "classified as %s match %s non-match " % (i, *msg) - ) diff --git a/tests/test_make_prg_from_msa.py b/tests/test_make_prg_from_msa.py index 7d150c9..4db23fc 100644 --- a/tests/test_make_prg_from_msa.py +++ b/tests/test_make_prg_from_msa.py @@ -8,8 +8,8 @@ from make_prg import MSA from make_prg.make_prg_from_msa import AlignedSeq -from make_prg.utils import ( - standard_bases, +from make_prg.seq_utils import standard_bases +from make_prg.interval_partition import ( enforce_multisequence_nonmatch_intervals, enforce_alignment_interval_bijection, PartitioningError, @@ -113,39 +113,50 @@ def test_bijection_respected_passes(self): class TestIntervalPartitioning(TestCase): def test_all_non_match(self, get_consensus, _, __): get_consensus.return_value = "******" - tester = AlignedSeq("_", alignment="_", min_match_length=3) + tester = AlignedSeq("_", alignment=MSA([]), min_match_length=3) match, non_match, _ = tester.partition_alignment_into_intervals() self.assertEqual(match, []) self.assertEqual(non_match, [[0, 5]]) def test_all_match(self, get_consensus, _, __): get_consensus.return_value = "ATATAAA" - tester = AlignedSeq("_", alignment="_", min_match_length=3) + tester = AlignedSeq("_", alignment=MSA([]), min_match_length=3) match, non_match, _ = tester.partition_alignment_into_intervals() self.assertEqual(match, [[0, 6]]) self.assertEqual(non_match, []) def test_short_match_counted_as_non_match(self, get_consensus, _, __): get_consensus.return_value = "AT***" - tester = AlignedSeq("_", alignment="_", min_match_length=3) + tester = AlignedSeq("_", alignment=MSA([]), min_match_length=3) match, non_match, _ = tester.partition_alignment_into_intervals() self.assertEqual(match, []) self.assertEqual(non_match, [[0, 4]]) def test_match_non_match_match(self, get_consensus, _, __): get_consensus.return_value = "ATT**AAAC" - tester = AlignedSeq("_", alignment="_", min_match_length=3) + tester = AlignedSeq("_", alignment=MSA([]), min_match_length=3) match, non_match, _ = tester.partition_alignment_into_intervals() self.assertEqual(match, [[0, 2], [5, 8]]) self.assertEqual(non_match, [[3, 4]]) def test_end_in_non_match(self, get_consensus, _, __): get_consensus.return_value = "**ATT**AAA*C" - tester = AlignedSeq("_", alignment="_", min_match_length=3) + tester = AlignedSeq("_", alignment=MSA([]), min_match_length=3) match, non_match, _ = tester.partition_alignment_into_intervals() self.assertEqual(match, [[2, 4], [7, 9]]) self.assertEqual(non_match, [[0, 1], [5, 6], [10, 11]]) + def test_consensus_smaller_than_min_match_len(self, get_consensus, _, __): + """ + Usually, a match smaller than min_match_length counts as non-match, + but if the whole string is smaller than min_match_length, counts as match. + """ + get_consensus.return_value = "TTATT" + tester = AlignedSeq("_", alignment=MSA([]), min_match_length=7) + match, non_match, _ = tester.partition_alignment_into_intervals() + self.assertEqual(match, [[0, 4]]) + self.assertEqual(non_match, []) + class TestKmeansClusters(TestCase): def test_one_seq_returns_single_id(self): diff --git a/tests/test_seq_utils.py b/tests/test_seq_utils.py index 3282952..678f0c2 100644 --- a/tests/test_seq_utils.py +++ b/tests/test_seq_utils.py @@ -4,7 +4,7 @@ from Bio.Seq import Seq from Bio.SeqRecord import SeqRecord -from make_prg.utils import get_interval_seqs +from make_prg.seq_utils import get_interval_seqs class TestGetIntervals(unittest.TestCase): From acff19b3315063a512407ebcdc3ff9d859e733fc Mon Sep 17 00:00:00 2001 From: Brice Letcher Date: Wed, 12 Aug 2020 10:40:50 +0100 Subject: [PATCH 7/9] Refactored interval partitioning (2) * Make unit tests work for new interval partitioning code * Use new interval partitioning code in make_prg code --- make_prg/interval_partition.py | 192 ++++++++++++++++++++----------- make_prg/make_prg_from_msa.py | 102 +++------------- make_prg/seq_utils.py | 6 + tests/__init__.py | 14 +++ tests/test_interval_partition.py | 129 +++++++++++++++++++++ tests/test_make_prg_from_msa.py | 114 +----------------- 6 files changed, 293 insertions(+), 264 deletions(-) create mode 100644 tests/test_interval_partition.py diff --git a/make_prg/interval_partition.py b/make_prg/interval_partition.py index 6e1cd52..c8451b0 100644 --- a/make_prg/interval_partition.py +++ b/make_prg/interval_partition.py @@ -3,11 +3,15 @@ match/non_match intervals. """ from enum import Enum, auto -from typing import List, Tuple +from typing import List, Tuple, Optional from make_prg import MSA -from make_prg.seq_utils import get_interval_seqs +from make_prg.seq_utils import get_interval_seqs, is_non_match + + +class PartitioningError(Exception): + pass class IntervalType(Enum): @@ -33,18 +37,30 @@ class Interval: def __init__(self, it_type: IntervalType, start: int, stop: int = None): self.type = it_type self.start = start + if stop is not None: + assert stop >= start self.stop = stop if stop is not None else start def modify_by(self, left_delta: int, right_delta: int): self.start += left_delta self.stop += right_delta + def contains(self, position: int): + return self.start <= position <= self.stop + def __len__(self) -> int: return self.stop - self.start + 1 def __lt__(self, other: "Interval") -> bool: return self.start < other.start + def __eq__(self, other: "Interval") -> bool: + return ( + self.start == other.start + and self.stop == other.stop + and self.type is other.type + ) + Intervals = List[Interval] @@ -59,15 +75,36 @@ def __init__(self, consensus_string: str, min_match_length: int, alignment: MSA) self._non_match_intervals: Intervals = list() self.mml = min_match_length - cur_interval = self._new_interval(consensus_string[0], 0) + if len(consensus_string) < self.mml: + # In this case, a match of less than the min_match_length gets counted + # as a match (usually, it counts as a non_match) + it_type = IntervalType.Match + if any(map(is_non_match, consensus_string)): + it_type = IntervalType.NonMatch + self._append(Interval(it_type, 0, len(consensus_string) - 1)) - for i, letter in enumerate(consensus_string[1:], start=1): - if is_type(letter, cur_interval.type): - cur_interval.modify_by(0, 1) # simple interval extension - else: - self._add_interval(cur_interval, alignment) - cur_interval = self._new_interval(letter, i) - self._add_interval(cur_interval, alignment) + else: + cur_interval = self._new_interval(consensus_string[0], 0) + + for i, letter in enumerate(consensus_string[1:], start=1): + if is_type(letter, cur_interval.type): + cur_interval.modify_by(0, 1) # simple interval extension + else: + new_interval = self._add_interval(cur_interval, alignment) + if new_interval is None: + cur_interval = self._new_interval(letter, i) + else: + cur_interval = new_interval + self._add_interval(cur_interval, alignment, end=True) + + self.enforce_multisequence_nonmatch_intervals( + self._match_intervals, self._non_match_intervals, alignment + ) + self.enforce_alignment_interval_bijection( + self._match_intervals, + self._non_match_intervals, + alignment.get_alignment_length(), + ) def get_intervals(self) -> Tuple[Intervals, Intervals, Intervals]: return ( @@ -91,71 +128,90 @@ def _pop(self, it_type: IntervalType) -> Interval: else: return self._non_match_intervals.pop() - def _add_interval(self, interval: Interval, alignment: MSA): + def _add_interval( + self, interval: Interval, alignment: MSA, end: bool = False + ) -> Optional[Interval]: + """ + If we are given a match interval < min_match_length, we return a new or extended + non_match interval + """ if interval.type is IntervalType.Match: + # The +1 is because we also extend the non_match interval if len(interval) < self.mml: try: last_non_match = self._pop(IntervalType.NonMatch) - last_non_match.modify_by(0, len(interval)) + last_non_match.modify_by(0, len(interval) + 1) except IndexError: last_non_match = Interval( - interval.start, interval.stop, IntervalType.NonMatch + IntervalType.NonMatch, interval.start, interval.stop + 1 ) - self.append(last_non_match) - return + if end: # If this is final call, go to append the interval + last_non_match.modify_by(0, -1) + self._append(last_non_match) + return last_non_match else: pass self._append(interval) - - -def enforce_multisequence_nonmatch_intervals( - match_intervals, non_match_intervals, alignment: MSA -): - """ - Goes through non-match intervals and makes sure there is more than one sequence there, else makes it a match - interval. - Example reasons for such a conversion to occur: - - 'N' in a sequence causes it to be filtered out, and left with a single useable sequence - - '-' in sequences causes them to appear different, but they are the same - """ - for i in reversed(range(len(non_match_intervals))): - interval = non_match_intervals[i] - interval_alignment = alignment[:, interval[0] : interval[1] + 1] - interval_seqs = get_interval_seqs(interval_alignment) - if len(interval_seqs) < 2: - match_intervals.append(non_match_intervals[i]) - non_match_intervals.pop(i) - match_intervals.sort() - - -class PartitioningError(Exception): - pass - - -def enforce_alignment_interval_bijection( - match_intervals, non_match_intervals, alignment_length: int -): - """ - Check each position in an alignment is in one, and one only, (match or non_match) interval - """ - for i in range(alignment_length): - count_match = 0 - for interval in match_intervals: - if interval[0] <= i <= interval[1]: - count_match += 1 - count_non_match = 0 - for interval in non_match_intervals: - if interval[0] <= i <= interval[1]: - count_non_match += 1 - - if count_match > 1 or count_non_match > 1: - raise PartitioningError( - f"Failed interval partitioning: position {i}" - " appears in more than one interval" - ) - if not count_match ^ count_non_match: # test fails if they are the same integer - msg = ["neither", "nor"] if count_match == 0 else ["both", "and"] - raise PartitioningError( - "Failed interval partitioning: alignment position %d" - "classified as %s match %s non-match " % (i, *msg) - ) + return None + + @classmethod + def enforce_multisequence_nonmatch_intervals( + cls, match_intervals: Intervals, non_match_intervals: Intervals, alignment: MSA + ): + """ + Goes through non-match intervals and makes sure there is more than one sequence there, else makes it a match + interval. + Example reasons for such a conversion to occur: + - 'N' in a sequence causes it to be filtered out, and left with a single useable sequence + - '-' in sequences causes them to appear different, but they are the same + """ + if len(alignment) == 0: # For testing convenience + return + for i in reversed(range(len(non_match_intervals))): + interval = non_match_intervals[i] + interval_alignment = alignment[:, interval.start : interval.stop + 1] + interval_seqs = get_interval_seqs(interval_alignment) + if len(interval_seqs) < 2: + changed_interval = non_match_intervals[i] + match_intervals.append( + Interval( + IntervalType.Match, + changed_interval.start, + changed_interval.stop, + ) + ) + non_match_intervals.pop(i) + + @classmethod + def enforce_alignment_interval_bijection( + cls, + match_intervals: Intervals, + non_match_intervals: Intervals, + alignment_length: int, + ): + """ + Check each position in an alignment is in one, and one only, (match or non_match) interval + """ + for i in range(alignment_length): + count_match = 0 + for interval in match_intervals: + if interval.contains(i): + count_match += 1 + count_non_match = 0 + for interval in non_match_intervals: + if interval.contains(i): + count_non_match += 1 + + if count_match > 1 or count_non_match > 1: + raise PartitioningError( + f"Failed interval partitioning: position {i}" + " appears in more than one interval" + ) + if ( + not count_match ^ count_non_match + ): # test fails if they are the same integer + msg = ["neither", "nor"] if count_match == 0 else ["both", "and"] + raise PartitioningError( + "Failed interval partitioning: alignment position %d" + "classified as %s match %s non-match " % (i, msg[0], msg[1]) + ) diff --git a/make_prg/make_prg_from_msa.py b/make_prg/make_prg_from_msa.py index 648662d..f984d0a 100644 --- a/make_prg/make_prg_from_msa.py +++ b/make_prg/make_prg_from_msa.py @@ -12,10 +12,7 @@ remove_duplicates, get_interval_seqs, ) -from make_prg.interval_partition import ( - enforce_multisequence_nonmatch_intervals, - enforce_alignment_interval_bijection, -) +from make_prg.interval_partition import IntervalPartitioner class AlignedSeq(object): @@ -53,7 +50,9 @@ def __init__( self.match_intervals, self.non_match_intervals, self.all_intervals, - ) = self.partition_alignment_into_intervals() + ) = IntervalPartitioner( + self.consensus, self.min_match_length, self.alignment + ).get_intervals() # properties for stats self.subAlignedSeqs = {} @@ -91,75 +90,6 @@ def get_consensus(cls, alignment: MSA): return consensus_string - def partition_alignment_into_intervals(self): - """Return a list of intervals in which we have - consensus sequence longer than min_match_length, and - a list of the non-match intervals left.""" - match_intervals, non_match_intervals = list(), list() - match_count, match_start, non_match_start = 0, 0, 0 - - logging.debug("consensus: %s" % self.consensus) - for i in range(self.length): - if self.consensus[i] != "*": - # In a match region. - if match_count == 0: - match_start = i - match_count += 1 - else: - if match_count == 0: - continue - logging.debug( - "have match string %s" - % self.consensus[match_start : match_start + match_count] - ) - - if match_count >= self.min_match_length: - if non_match_start < match_start: - non_match_intervals.append([non_match_start, match_start - 1]) - logging.debug( - f"add non-match interval [{non_match_start},{match_start - 1}]" - ) - end = match_start + match_count - 1 - match_intervals.append([match_start, end]) - logging.debug(f"add match interval [{match_start},{end}]") - non_match_start = match_start = i - match_count = 0 - - end = self.length - 1 - if self.length < self.min_match_length: - # Special case: a short sequence can still get classified as a match interval - added_interval = "non_match" if "*" in self.consensus else "match" - if added_interval == "match": - match_intervals.append([0, end]) - else: - non_match_intervals.append([0, end]) - logging.debug(f"add whole short {added_interval} interval [0,{end}]") - match_count = 0 - non_match_start = end + 1 - - # At end add last intervals - if match_count > 0: - if match_count >= self.min_match_length: - match_intervals.append([match_start, end]) - logging.debug(f"add final match interval [{match_start},{end}]") - if non_match_start < match_start: - end = match_start - 1 - if match_count != self.length and non_match_start <= end: - non_match_intervals.append([non_match_start, end]) - logging.debug(f"add non-match interval [{non_match_start},{end}]") - - enforce_alignment_interval_bijection( - match_intervals, non_match_intervals, self.length - ) - - logging.info("Non match intervals: %s", non_match_intervals) - enforce_multisequence_nonmatch_intervals( - match_intervals, non_match_intervals, self.alignment - ) - all_intervals = match_intervals + non_match_intervals - all_intervals.sort() - return match_intervals, non_match_intervals, all_intervals - @classmethod def kmeans_cluster_seqs_in_interval( self, interval: List[int], alignment: MSA, min_match_length: int, @@ -303,7 +233,7 @@ def get_prg(self): if interval in self.match_intervals: # all seqs are not necessarily exactly the same: some can have 'N' # thus still process all of them, to get the one with no 'N'. - sub_alignment = self.alignment[:, interval[0] : interval[1] + 1] + sub_alignment = self.alignment[:, interval.start : interval.stop + 1] seqs = get_interval_seqs(sub_alignment) assert len(seqs) == 1, "Got >1 filtered sequences in match interval" seq = seqs[0] @@ -317,13 +247,15 @@ def get_prg(self): # Define the variant seqs to add if (self.nesting_level == self.max_nesting) or ( - interval[1] - interval[0] <= self.min_match_length + interval.stop - interval.start <= self.min_match_length ): logging.debug( "Have reached max nesting level or have a small variant site, so add all variant " "sequences in interval." ) - sub_alignment = self.alignment[:, interval[0] : interval[1] + 1] + sub_alignment = self.alignment[ + :, interval.start : interval.stop + 1 + ] variant_prgs = get_interval_seqs(sub_alignment) logging.debug(f"Variant seqs found: {variant_prgs}") else: @@ -332,11 +264,13 @@ def get_prg(self): ) recur = True id_lists = self.kmeans_cluster_seqs_in_interval( - interval, self.alignment, self.min_match_length + [interval.start, interval.stop], + self.alignment, + self.min_match_length, ) list_sub_alignments = [ self.get_sub_alignment_by_list_id( - id_list, self.alignment, interval + id_list, self.alignment, [interval.start, interval.stop] ) for id_list in id_lists ] @@ -347,8 +281,8 @@ def get_prg(self): "Clustering did not group any sequences together, each seq is a cluster" ) recur = False - elif interval[0] not in self.subAlignedSeqs: - self.subAlignedSeqs[interval[0]] = [] + elif interval.start not in self.subAlignedSeqs: + self.subAlignedSeqs[interval.start] = [] logging.debug( "subAlignedSeqs now has keys: %s", list(self.subAlignedSeqs.keys()), @@ -356,7 +290,7 @@ def get_prg(self): else: logging.debug( "subAlignedSeqs already had key %d in keys: %s. This shouldn't happen.", - interval[0], + interval.start, list(self.subAlignedSeqs.keys()), ) @@ -376,7 +310,7 @@ def get_prg(self): self.site = sub_aligned_seq.site if recur: - self.subAlignedSeqs[interval[0]].append(sub_aligned_seq) + self.subAlignedSeqs[interval.start].append(sub_aligned_seq) assert num_clusters == len(variant_prgs), ( "I don't seem to have a sub-prg sequence for all parts of the partition - there are %d " "classes in partition, and %d variant seqs" @@ -437,7 +371,7 @@ def max_nesting_level_reached(self): def prop_in_match_intervals(self): length_match_intervals = 0 for interval in self.match_intervals: - length_match_intervals += interval[1] - interval[0] + 1 + length_match_intervals += interval.stop - interval.start + 1 return length_match_intervals / float(self.length) @property diff --git a/make_prg/seq_utils.py b/make_prg/seq_utils.py index b08b2bf..13d0e4b 100644 --- a/make_prg/seq_utils.py +++ b/make_prg/seq_utils.py @@ -4,6 +4,12 @@ from Bio import AlignIO +NONMATCH = "*" + + +def is_non_match(letter: str): + return letter == NONMATCH + def remove_duplicates(seqs: Sequence) -> Generator: seen = set() diff --git a/tests/__init__.py b/tests/__init__.py index e69de29..a4b2330 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,14 @@ +from typing import List + +from Bio.Seq import Seq +from Bio.SeqRecord import SeqRecord + +from make_prg import MSA + + +def make_alignment(seqs: List[str], ids: List[str] = None) -> MSA: + if ids is None: + seqrecords = [SeqRecord(Seq(seq), id=f"s_{i}") for i, seq in enumerate(seqs)] + else: + seqrecords = [SeqRecord(Seq(seq), id=ID) for seq, ID in zip(seqs, ids)] + return MSA(seqrecords) diff --git a/tests/test_interval_partition.py b/tests/test_interval_partition.py new file mode 100644 index 0000000..cf27470 --- /dev/null +++ b/tests/test_interval_partition.py @@ -0,0 +1,129 @@ +from unittest import TestCase +from typing import List + +from tests import make_alignment, MSA +from make_prg.interval_partition import ( + IntervalType, + Interval, + IntervalPartitioner, + PartitioningError, +) + +Lists = List[List[int]] + +Match = IntervalType.Match +NonMatch = IntervalType.NonMatch + + +def make_typed_intervals(lists: Lists, it_type: IntervalType): + result = list() + for elem in lists: + result.append(Interval(it_type, elem[0], elem[1])) + return result + + +def make_intervals(match_lists: Lists, non_match_lists: Lists): + return ( + make_typed_intervals(match_lists, Match), + make_typed_intervals(non_match_lists, NonMatch), + ) + + +class TestIntervalConsistency(TestCase): + def test_nonmatch_interval_switching_indels(self): + """Because the sequences are the same, despite different alignment""" + alignment = make_alignment(["A---A", "A-A--"]) + match_intervals, non_match_intervals = make_intervals([], [[0, 5]]) + IntervalPartitioner.enforce_multisequence_nonmatch_intervals( + match_intervals, non_match_intervals, alignment + ) + self.assertEqual(match_intervals, make_typed_intervals([[0, 5]], Match)) + self.assertEqual(non_match_intervals, []) + + def test_nonmatch_interval_switching_Ns(self): + """'N's make sequences get removed""" + alignment = make_alignment(["ANAAA", "ATAAT"]) + match_intervals, non_match_intervals = make_intervals([], [[0, 5]]) + IntervalPartitioner.enforce_multisequence_nonmatch_intervals( + match_intervals, non_match_intervals, alignment + ) + self.assertEqual(match_intervals, make_typed_intervals([[0, 5]], Match)) + self.assertEqual(non_match_intervals, []) + + def test_position_in_several_intervals_fails(self): + match_intervals = make_typed_intervals([[0, 1], [1, 2]], Match) + with self.assertRaises(PartitioningError): + IntervalPartitioner.enforce_alignment_interval_bijection( + match_intervals, [], 3 + ) + + def test_position_in_no_interval_fails(self): + match_intervals = make_typed_intervals([[0, 1]], Match) + with self.assertRaises(PartitioningError): + IntervalPartitioner.enforce_alignment_interval_bijection( + match_intervals, [], 3 + ) + + def test_position_in_match_and_nonmatch_intervals_fails(self): + match_intervals, nmatch_intervals = make_intervals([[0, 2]], [[2, 3]]) + with self.assertRaises(PartitioningError): + IntervalPartitioner.enforce_alignment_interval_bijection( + match_intervals, nmatch_intervals, 4 + ) + + def test_bijection_respected_passes(self): + match_intervals, nmatch_intervals = make_intervals([[0, 2], [5, 10]], [[3, 4]]) + IntervalPartitioner.enforce_alignment_interval_bijection( + match_intervals, nmatch_intervals, 11 + ) + + +class TestIntervalPartitioning(TestCase): + def test_all_non_match(self): + tester = IntervalPartitioner("******", min_match_length=3, alignment=MSA([])) + match, non_match, _ = tester.get_intervals() + self.assertEqual(match, []) + self.assertEqual(non_match, make_typed_intervals([[0, 5]], NonMatch)) + + def test_all_match(self): + tester = IntervalPartitioner("ATATAAA", min_match_length=3, alignment=MSA([])) + match, non_match, _ = tester.get_intervals() + self.assertEqual(match, make_typed_intervals([[0, 6]], Match)) + self.assertEqual(non_match, []) + + def test_short_match_counted_as_non_match(self): + tester = IntervalPartitioner("AT***", min_match_length=3, alignment=MSA([])) + match, non_match, _ = tester.get_intervals() + self.assertEqual(match, []) + self.assertEqual(non_match, make_typed_intervals([[0, 4]], NonMatch)) + + def test_match_non_match_match(self): + tester = IntervalPartitioner("ATT**AAAC", min_match_length=3, alignment=MSA([])) + match, non_match, _ = tester.get_intervals() + self.assertEqual(match, make_typed_intervals([[0, 2], [5, 8]], Match)) + self.assertEqual(non_match, make_typed_intervals([[3, 4]], NonMatch)) + + def test_end_in_non_match(self): + tester = IntervalPartitioner( + "**ATT**AAA*C", min_match_length=3, alignment=MSA([]) + ) + match, non_match, _ = tester.get_intervals() + self.assertEqual(match, make_typed_intervals([[2, 4], [7, 9]], Match)) + self.assertEqual( + non_match, make_typed_intervals([[0, 1], [5, 6], [10, 11]], NonMatch) + ) + + def test_consensus_smaller_than_min_match_len(self): + """ + Usually, a match smaller than min_match_length counts as non-match, + but if the whole string is smaller than min_match_length, counts as match. + """ + tester1 = IntervalPartitioner("TTATT", min_match_length=7, alignment=MSA([])) + match, non_match, _ = tester1.get_intervals() + self.assertEqual(match, make_typed_intervals([[0, 4]], Match)) + self.assertEqual(non_match, []) + + tester2 = IntervalPartitioner("T*ATT", min_match_length=7, alignment=MSA([])) + match, non_match, _ = tester2.get_intervals() + self.assertEqual(match, []) + self.assertEqual(non_match, make_typed_intervals([[0, 4]], NonMatch)) diff --git a/tests/test_make_prg_from_msa.py b/tests/test_make_prg_from_msa.py index 4db23fc..9ec6f58 100644 --- a/tests/test_make_prg_from_msa.py +++ b/tests/test_make_prg_from_msa.py @@ -1,32 +1,18 @@ import os import random -from unittest import TestCase, mock, skip -from typing import List +from unittest import TestCase, skip from Bio.Seq import Seq from Bio.SeqRecord import SeqRecord -from make_prg import MSA from make_prg.make_prg_from_msa import AlignedSeq from make_prg.seq_utils import standard_bases -from make_prg.interval_partition import ( - enforce_multisequence_nonmatch_intervals, - enforce_alignment_interval_bijection, - PartitioningError, -) +from tests import make_alignment, MSA this_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) data_dir = os.path.join(this_dir, "tests", "data", "make_prg_from_msa") -def make_alignment(seqs: List[str], ids: List[str] = None) -> MSA: - if ids is None: - seqrecords = [SeqRecord(Seq(seq), id=f"s_{i}") for i, seq in enumerate(seqs)] - else: - seqrecords = [SeqRecord(Seq(seq), id=ID) for seq, ID in zip(seqs, ids)] - return MSA(seqrecords) - - class TestConsensusString(TestCase): def test_all_match(self): alignment = make_alignment(["AATTA", "AATTA"]) @@ -62,102 +48,6 @@ def test_all_gap_nonmatch(self): self.assertEqual(result, "A**A") -class TestIntervalConsistency(TestCase): - def test_nonmatch_interval_switching_indels(self): - """Because the sequences are the same, despite different alignment""" - alignment = make_alignment(["A---A", "A-A--"]) - match_intervals = [] - non_match_intervals = [[0, 5]] - enforce_multisequence_nonmatch_intervals( - match_intervals, non_match_intervals, alignment - ) - self.assertEqual(match_intervals, [[0, 5]]) - self.assertEqual(non_match_intervals, []) - - def test_nonmatch_interval_switching_Ns(self): - """'N's make sequences get removed""" - alignment = make_alignment(["ANAAA", "ATAAT"]) - match_intervals = [] - non_match_intervals = [[0, 5]] - enforce_multisequence_nonmatch_intervals( - match_intervals, non_match_intervals, alignment - ) - self.assertEqual(match_intervals, [[0, 5]]) - self.assertEqual(non_match_intervals, []) - - def test_position_in_several_intervals_fails(self): - match_intervals = [[0, 1], [1, 2]] - with self.assertRaises(PartitioningError): - enforce_alignment_interval_bijection(match_intervals, [], 3) - - def test_position_in_no_interval_fails(self): - match_intervals = [[0, 1]] - with self.assertRaises(PartitioningError): - enforce_alignment_interval_bijection(match_intervals, [], 3) - - def test_position_in_match_and_nonmatch_intervals_fails(self): - match_intervals = [[0, 2]] - nmatch_intervals = [[2, 3]] - with self.assertRaises(PartitioningError): - enforce_alignment_interval_bijection(match_intervals, nmatch_intervals, 4) - - def test_bijection_respected_passes(self): - match_intervals = [[0, 2], [5, 10]] - nmatch_intervals = [[3, 4]] - enforce_alignment_interval_bijection(match_intervals, nmatch_intervals, 11) - - -@mock.patch("make_prg.make_prg_from_msa.enforce_multisequence_nonmatch_intervals") -@mock.patch.object(AlignedSeq, "get_prg") -@mock.patch.object(AlignedSeq, "get_consensus") -class TestIntervalPartitioning(TestCase): - def test_all_non_match(self, get_consensus, _, __): - get_consensus.return_value = "******" - tester = AlignedSeq("_", alignment=MSA([]), min_match_length=3) - match, non_match, _ = tester.partition_alignment_into_intervals() - self.assertEqual(match, []) - self.assertEqual(non_match, [[0, 5]]) - - def test_all_match(self, get_consensus, _, __): - get_consensus.return_value = "ATATAAA" - tester = AlignedSeq("_", alignment=MSA([]), min_match_length=3) - match, non_match, _ = tester.partition_alignment_into_intervals() - self.assertEqual(match, [[0, 6]]) - self.assertEqual(non_match, []) - - def test_short_match_counted_as_non_match(self, get_consensus, _, __): - get_consensus.return_value = "AT***" - tester = AlignedSeq("_", alignment=MSA([]), min_match_length=3) - match, non_match, _ = tester.partition_alignment_into_intervals() - self.assertEqual(match, []) - self.assertEqual(non_match, [[0, 4]]) - - def test_match_non_match_match(self, get_consensus, _, __): - get_consensus.return_value = "ATT**AAAC" - tester = AlignedSeq("_", alignment=MSA([]), min_match_length=3) - match, non_match, _ = tester.partition_alignment_into_intervals() - self.assertEqual(match, [[0, 2], [5, 8]]) - self.assertEqual(non_match, [[3, 4]]) - - def test_end_in_non_match(self, get_consensus, _, __): - get_consensus.return_value = "**ATT**AAA*C" - tester = AlignedSeq("_", alignment=MSA([]), min_match_length=3) - match, non_match, _ = tester.partition_alignment_into_intervals() - self.assertEqual(match, [[2, 4], [7, 9]]) - self.assertEqual(non_match, [[0, 1], [5, 6], [10, 11]]) - - def test_consensus_smaller_than_min_match_len(self, get_consensus, _, __): - """ - Usually, a match smaller than min_match_length counts as non-match, - but if the whole string is smaller than min_match_length, counts as match. - """ - get_consensus.return_value = "TTATT" - tester = AlignedSeq("_", alignment=MSA([]), min_match_length=7) - match, non_match, _ = tester.partition_alignment_into_intervals() - self.assertEqual(match, [[0, 4]]) - self.assertEqual(non_match, []) - - class TestKmeansClusters(TestCase): def test_one_seq_returns_single_id(self): alignment = MSA([SeqRecord(Seq("AAAT"), id="s1")]) From 13e5a52c9e5fe3db2956941892fc7b2505083fe6 Mon Sep 17 00:00:00 2001 From: Brice Letcher Date: Wed, 12 Aug 2020 14:08:15 +0100 Subject: [PATCH 8/9] Centralised setting of verbosity at command line --- make_prg/__main__.py | 15 +++++++++++++++ make_prg/interval_partition.py | 8 ++++++-- make_prg/make_prg_from_msa.py | 8 +++++++- make_prg/seq_utils.py | 2 +- make_prg/subcommands/prg_from_msa.py | 27 +++++++-------------------- 5 files changed, 36 insertions(+), 24 deletions(-) diff --git a/make_prg/__main__.py b/make_prg/__main__.py index edc8276..8df8575 100644 --- a/make_prg/__main__.py +++ b/make_prg/__main__.py @@ -1,4 +1,5 @@ import argparse +import logging from make_prg import __version__ from make_prg.subcommands import prg_from_msa @@ -16,10 +17,24 @@ def main(): title="Available subcommands", help="", metavar="" ) + parser.add_argument( + "-v", + "--verbose", + dest="verbose", + action="store_true", + help="Run with high verbosity " "(debug level logging)", + ) + prg_from_msa.register_parser(subparsers) args = parser.parse_args() + if args.verbose: + log_level = logging.DEBUG + else: + log_level = logging.INFO + logging.basicConfig(level=log_level, handlers=[]) + if hasattr(args, "func"): args.func(args) else: diff --git a/make_prg/interval_partition.py b/make_prg/interval_partition.py index c8451b0..80077c8 100644 --- a/make_prg/interval_partition.py +++ b/make_prg/interval_partition.py @@ -61,6 +61,9 @@ def __eq__(self, other: "Interval") -> bool: and self.type is other.type ) + def __repr__(self): + return f"[{self.start}, {self.stop}]" + Intervals = List[Interval] @@ -132,8 +135,9 @@ def _add_interval( self, interval: Interval, alignment: MSA, end: bool = False ) -> Optional[Interval]: """ - If we are given a match interval < min_match_length, we return a new or extended - non_match interval + i)If we are given a match interval < min_match_length, we return an extended non_match interval + ii)If we are given a non_match interval containing 1+ empty sequence, we pad it with + previous match_interval, if any, to avoid empty alleles in resulting prg. """ if interval.type is IntervalType.Match: # The +1 is because we also extend the non_match interval diff --git a/make_prg/make_prg_from_msa.py b/make_prg/make_prg_from_msa.py index f984d0a..3e0137f 100644 --- a/make_prg/make_prg_from_msa.py +++ b/make_prg/make_prg_from_msa.py @@ -11,6 +11,7 @@ ambiguous_bases, remove_duplicates, get_interval_seqs, + NONMATCH, ) from make_prg.interval_partition import IntervalPartitioner @@ -53,6 +54,11 @@ def __init__( ) = IntervalPartitioner( self.consensus, self.min_match_length, self.alignment ).get_intervals() + logging.info( + "match intervals: %s; non_match intervals: %s", + self.match_intervals, + self.non_match_intervals, + ) # properties for stats self.subAlignedSeqs = {} @@ -84,7 +90,7 @@ def get_consensus(cls, alignment: MSA): or len(column) != 1 or column == {"-"} ): - consensus_string += "*" + consensus_string += NONMATCH else: consensus_string += column.pop() diff --git a/make_prg/seq_utils.py b/make_prg/seq_utils.py index 13d0e4b..a58349a 100644 --- a/make_prg/seq_utils.py +++ b/make_prg/seq_utils.py @@ -62,7 +62,7 @@ def get_interval_seqs(interval_alignment: AlignIO.MultipleSeqAlignment): if len(expanded_set) == 0: logging.warning( - "WARNING: Every sequence must have contained an N in this slice - redo sequence curation because this is nonsense" + "WARNING: Every sequence must have contained an N in this slice - redo sequence curation" ) logging.warning(f'Sequences were: {" ".join(callback_seqs)}') logging.warning( diff --git a/make_prg/subcommands/prg_from_msa.py b/make_prg/subcommands/prg_from_msa.py index faa7c32..4738b58 100644 --- a/make_prg/subcommands/prg_from_msa.py +++ b/make_prg/subcommands/prg_from_msa.py @@ -63,13 +63,6 @@ def register_parser(subparsers): action="store_true", help="Do not overwrite pre-existing prg file with same name", ) - subparser_prg_from_msa.add_argument( - "-v", - "--verbose", - dest="verbose", - action="store_true", - help="Run with high verbosity " "(debug level logging)", - ) subparser_prg_from_msa.set_defaults(func=run) return subparser_prg_from_msa @@ -88,23 +81,17 @@ def run(options): options.min_match_length, ) - if options.verbose: - log_level = logging.DEBUG - msg = "Using debug logging" - else: - log_level = logging.INFO - msg = "Using info logging" - + # Set up file logging log_file = f"{prefix}.log" if os.path.exists(log_file): os.unlink(log_file) - logging.basicConfig( - filename=log_file, - level=log_level, - format="%(asctime)s %(message)s", - datefmt="%d/%m/%Y %I:%M:%S", + formatter = logging.Formatter( + fmt="%(levelname)s %(asctime)s %(message)s", datefmt="%d/%m/%Y %I:%M:%S" ) - logging.info(msg) + handler = logging.FileHandler(log_file) + handler.setFormatter(formatter) + logging.getLogger().addHandler(handler) + logging.info( "Input parameters max_nesting: %d, min_match_length: %d", options.max_nesting, From 73af160368fd91e308f26f2d8180a24789eb6b3c Mon Sep 17 00:00:00 2001 From: Brice Letcher Date: Wed, 12 Aug 2020 15:39:07 +0100 Subject: [PATCH 9/9] Interval partitioning to reduce empty alleles Addresses #17 --- make_prg/interval_partition.py | 41 +++++++++++++++++++++-------- make_prg/seq_utils.py | 19 +++++++++++--- tests/test_interval_partition.py | 44 +++++++++++++++++++++++++++++--- tests/test_seq_utils.py | 13 +++++++--- 4 files changed, 97 insertions(+), 20 deletions(-) diff --git a/make_prg/interval_partition.py b/make_prg/interval_partition.py index 80077c8..0511661 100644 --- a/make_prg/interval_partition.py +++ b/make_prg/interval_partition.py @@ -7,7 +7,7 @@ from make_prg import MSA -from make_prg.seq_utils import get_interval_seqs, is_non_match +from make_prg.seq_utils import get_interval_seqs, is_non_match, has_empty_sequence class PartitioningError(Exception): @@ -18,22 +18,24 @@ class IntervalType(Enum): Match = auto() NonMatch = auto() - -def get_type(letter: str) -> IntervalType: - if letter == "*": - return IntervalType.NonMatch - else: - return IntervalType.Match + @classmethod + def from_char(cls, letter: str) -> "IntervalType": + if letter == "*": + return IntervalType.NonMatch + else: + return IntervalType.Match def is_type(letter: str, interval_type: IntervalType) -> bool: - if get_type(letter) is interval_type: + if IntervalType.from_char(letter) is interval_type: return True else: return False class Interval: + """Stores a closed interval [a,b]""" + def __init__(self, it_type: IntervalType, start: int, stop: int = None): self.type = it_type self.start = start @@ -117,7 +119,7 @@ def get_intervals(self) -> Tuple[Intervals, Intervals, Intervals]: ) def _new_interval(self, letter: str, start_pos: int) -> Interval: - return Interval(get_type(letter), start_pos) + return Interval(IntervalType.from_char(letter), start_pos) def _append(self, interval: Interval): if interval.type is IntervalType.Match: @@ -154,17 +156,34 @@ def _add_interval( self._append(last_non_match) return last_non_match else: - pass + if len(self._match_intervals) > 0 and has_empty_sequence( + alignment, (interval.start, interval.stop) + ): + # Pad interval with sequence to avoid empty alleles + len_match = len(self._match_intervals[-1]) + if len_match - 1 < self.mml: + # Case: match is now too small, converted to non_match + self._match_intervals.pop() + interval.modify_by(-1 * len_match, 0) + if len(self._non_match_intervals) > 0: + # Case: merge previous non_match with this non_match + self._non_match_intervals[-1].modify_by(0, len(interval)) + return None + else: + self._match_intervals[-1].modify_by(0, -1) + interval.modify_by(-1, 0) + self._append(interval) return None @classmethod def enforce_multisequence_nonmatch_intervals( cls, match_intervals: Intervals, non_match_intervals: Intervals, alignment: MSA - ): + ) -> None: """ Goes through non-match intervals and makes sure there is more than one sequence there, else makes it a match interval. + Modifies the intervals in-place. Example reasons for such a conversion to occur: - 'N' in a sequence causes it to be filtered out, and left with a single useable sequence - '-' in sequences causes them to appear different, but they are the same diff --git a/make_prg/seq_utils.py b/make_prg/seq_utils.py index a58349a..0139272 100644 --- a/make_prg/seq_utils.py +++ b/make_prg/seq_utils.py @@ -1,16 +1,21 @@ import logging -from typing import Generator, Sequence +from typing import Generator, Sequence, Tuple import itertools -from Bio import AlignIO +from make_prg import MSA NONMATCH = "*" +GAP = "-" def is_non_match(letter: str): return letter == NONMATCH +def is_gap(letter: str): + return letter == GAP + + def remove_duplicates(seqs: Sequence) -> Generator: seen = set() for x in seqs: @@ -37,7 +42,15 @@ def remove_duplicates(seqs: Sequence) -> Generator: ambiguous_bases = allowed_bases.difference(standard_bases) -def get_interval_seqs(interval_alignment: AlignIO.MultipleSeqAlignment): +def has_empty_sequence(alignment: MSA, interval: Tuple[int, int]) -> bool: + sub_alignment = alignment[:, interval[0] : interval[1] + 1] + for record in sub_alignment: + if all(map(is_gap, record.seq)): + return True + return False + + +def get_interval_seqs(interval_alignment: MSA): """ Replace - with nothing, remove seqs containing N or other non-allowed letters and duplicate sequences containing RYKMSW, replacing with AGCT alternatives diff --git a/tests/test_interval_partition.py b/tests/test_interval_partition.py index cf27470..d4ca450 100644 --- a/tests/test_interval_partition.py +++ b/tests/test_interval_partition.py @@ -99,9 +99,16 @@ def test_short_match_counted_as_non_match(self): def test_match_non_match_match(self): tester = IntervalPartitioner("ATT**AAAC", min_match_length=3, alignment=MSA([])) - match, non_match, _ = tester.get_intervals() - self.assertEqual(match, make_typed_intervals([[0, 2], [5, 8]], Match)) - self.assertEqual(non_match, make_typed_intervals([[3, 4]], NonMatch)) + match, non_match, all_match = tester.get_intervals() + expected_matches = make_typed_intervals([[0, 2], [5, 8]], Match) + expected_non_matches = make_typed_intervals([[3, 4]], NonMatch) + self.assertEqual(match, expected_matches) + self.assertEqual(non_match, expected_non_matches) + # Check interval sorting works + self.assertEqual( + all_match, + [expected_matches[0], expected_non_matches[0], expected_matches[1]], + ) def test_end_in_non_match(self): tester = IntervalPartitioner( @@ -127,3 +134,34 @@ def test_consensus_smaller_than_min_match_len(self): match, non_match, _ = tester2.get_intervals() self.assertEqual(match, []) self.assertEqual(non_match, make_typed_intervals([[0, 4]], NonMatch)) + + def test_avoid_empty_alleles_long_match(self): + """ + If we let the non-match interval be only [4,5], + this would result in an empty allele in the prg, + so require padding using the preceding match sequence + """ + msa = make_alignment(["TTAAGGTTT", "TTAA--TTT"]) + tester = IntervalPartitioner("TTAA**TTT", min_match_length=3, alignment=msa) + match, non_match, _ = tester.get_intervals() + self.assertEqual(match, make_typed_intervals([[0, 2], [6, 8]], Match)) + self.assertEqual(non_match, make_typed_intervals([[3, 5]], NonMatch)) + + def test_avoid_empty_alleles_short_match(self): + """ + Padding behaviour also expected, but now the leading match interval becomes too + short and collapses to a non_match interval + """ + msa = make_alignment(["TTAGGTTT", "TTA--TTT"]) + tester = IntervalPartitioner("TTA**TTT", min_match_length=3, alignment=msa) + match, non_match, _ = tester.get_intervals() + self.assertEqual(match, make_typed_intervals([[5, 7]], Match)) + self.assertEqual(non_match, make_typed_intervals([[0, 4]], NonMatch)) + + def test_avoid_empty_alleles_previous_non_match_merged(self): + """Edge case of collapsed match interval, part 2""" + msa = make_alignment(["CCTTAGGTTT", "AATTA--TTT"]) + tester = IntervalPartitioner("**TTA**TTT", min_match_length=3, alignment=msa) + match, non_match, _ = tester.get_intervals() + self.assertEqual(match, make_typed_intervals([[7, 9]], Match)) + self.assertEqual(non_match, make_typed_intervals([[0, 6]], NonMatch)) diff --git a/tests/test_seq_utils.py b/tests/test_seq_utils.py index 678f0c2..5804b7c 100644 --- a/tests/test_seq_utils.py +++ b/tests/test_seq_utils.py @@ -1,13 +1,14 @@ -import unittest +from unittest import TestCase from Bio import AlignIO from Bio.Seq import Seq from Bio.SeqRecord import SeqRecord -from make_prg.seq_utils import get_interval_seqs +from tests import make_alignment +from make_prg.seq_utils import get_interval_seqs, has_empty_sequence -class TestGetIntervals(unittest.TestCase): +class TestGetIntervals(TestCase): def test_ambiguous_bases_one_seq(self): alignment = AlignIO.MultipleSeqAlignment([SeqRecord(Seq("RWAAT"))]) result = get_interval_seqs(alignment) @@ -27,3 +28,9 @@ def test_first_sequence_in_is_first_sequence_out(self): result = get_interval_seqs(alignment) expected = ["TTTT", "AAAA", "CCC"] self.assertEqual(expected, result) + + +class TestEmptySeq(TestCase): + def test_sub_alignment_with_empty_sequence(self): + msa = make_alignment(["TTAGGTTT", "TTA--TTT", "GGA-TTTT"]) + self.assertTrue(has_empty_sequence(msa, [3, 4]))