diff --git a/make_prg/interval_partition.py b/make_prg/interval_partition.py index 80077c8..0511661 100644 --- a/make_prg/interval_partition.py +++ b/make_prg/interval_partition.py @@ -7,7 +7,7 @@ from make_prg import MSA -from make_prg.seq_utils import get_interval_seqs, is_non_match +from make_prg.seq_utils import get_interval_seqs, is_non_match, has_empty_sequence class PartitioningError(Exception): @@ -18,22 +18,24 @@ class IntervalType(Enum): Match = auto() NonMatch = auto() - -def get_type(letter: str) -> IntervalType: - if letter == "*": - return IntervalType.NonMatch - else: - return IntervalType.Match + @classmethod + def from_char(cls, letter: str) -> "IntervalType": + if letter == "*": + return IntervalType.NonMatch + else: + return IntervalType.Match def is_type(letter: str, interval_type: IntervalType) -> bool: - if get_type(letter) is interval_type: + if IntervalType.from_char(letter) is interval_type: return True else: return False class Interval: + """Stores a closed interval [a,b]""" + def __init__(self, it_type: IntervalType, start: int, stop: int = None): self.type = it_type self.start = start @@ -117,7 +119,7 @@ def get_intervals(self) -> Tuple[Intervals, Intervals, Intervals]: ) def _new_interval(self, letter: str, start_pos: int) -> Interval: - return Interval(get_type(letter), start_pos) + return Interval(IntervalType.from_char(letter), start_pos) def _append(self, interval: Interval): if interval.type is IntervalType.Match: @@ -154,17 +156,34 @@ def _add_interval( self._append(last_non_match) return last_non_match else: - pass + if len(self._match_intervals) > 0 and has_empty_sequence( + alignment, (interval.start, interval.stop) + ): + # Pad interval with sequence to avoid empty alleles + len_match = len(self._match_intervals[-1]) + if len_match - 1 < self.mml: + # Case: match is now too small, converted to non_match + self._match_intervals.pop() + interval.modify_by(-1 * len_match, 0) + if len(self._non_match_intervals) > 0: + # Case: merge previous non_match with this non_match + self._non_match_intervals[-1].modify_by(0, len(interval)) + return None + else: + self._match_intervals[-1].modify_by(0, -1) + interval.modify_by(-1, 0) + self._append(interval) return None @classmethod def enforce_multisequence_nonmatch_intervals( cls, match_intervals: Intervals, non_match_intervals: Intervals, alignment: MSA - ): + ) -> None: """ Goes through non-match intervals and makes sure there is more than one sequence there, else makes it a match interval. + Modifies the intervals in-place. Example reasons for such a conversion to occur: - 'N' in a sequence causes it to be filtered out, and left with a single useable sequence - '-' in sequences causes them to appear different, but they are the same diff --git a/make_prg/seq_utils.py b/make_prg/seq_utils.py index a58349a..0139272 100644 --- a/make_prg/seq_utils.py +++ b/make_prg/seq_utils.py @@ -1,16 +1,21 @@ import logging -from typing import Generator, Sequence +from typing import Generator, Sequence, Tuple import itertools -from Bio import AlignIO +from make_prg import MSA NONMATCH = "*" +GAP = "-" def is_non_match(letter: str): return letter == NONMATCH +def is_gap(letter: str): + return letter == GAP + + def remove_duplicates(seqs: Sequence) -> Generator: seen = set() for x in seqs: @@ -37,7 +42,15 @@ def remove_duplicates(seqs: Sequence) -> Generator: ambiguous_bases = allowed_bases.difference(standard_bases) -def get_interval_seqs(interval_alignment: AlignIO.MultipleSeqAlignment): +def has_empty_sequence(alignment: MSA, interval: Tuple[int, int]) -> bool: + sub_alignment = alignment[:, interval[0] : interval[1] + 1] + for record in sub_alignment: + if all(map(is_gap, record.seq)): + return True + return False + + +def get_interval_seqs(interval_alignment: MSA): """ Replace - with nothing, remove seqs containing N or other non-allowed letters and duplicate sequences containing RYKMSW, replacing with AGCT alternatives diff --git a/tests/test_interval_partition.py b/tests/test_interval_partition.py index cf27470..d4ca450 100644 --- a/tests/test_interval_partition.py +++ b/tests/test_interval_partition.py @@ -99,9 +99,16 @@ def test_short_match_counted_as_non_match(self): def test_match_non_match_match(self): tester = IntervalPartitioner("ATT**AAAC", min_match_length=3, alignment=MSA([])) - match, non_match, _ = tester.get_intervals() - self.assertEqual(match, make_typed_intervals([[0, 2], [5, 8]], Match)) - self.assertEqual(non_match, make_typed_intervals([[3, 4]], NonMatch)) + match, non_match, all_match = tester.get_intervals() + expected_matches = make_typed_intervals([[0, 2], [5, 8]], Match) + expected_non_matches = make_typed_intervals([[3, 4]], NonMatch) + self.assertEqual(match, expected_matches) + self.assertEqual(non_match, expected_non_matches) + # Check interval sorting works + self.assertEqual( + all_match, + [expected_matches[0], expected_non_matches[0], expected_matches[1]], + ) def test_end_in_non_match(self): tester = IntervalPartitioner( @@ -127,3 +134,34 @@ def test_consensus_smaller_than_min_match_len(self): match, non_match, _ = tester2.get_intervals() self.assertEqual(match, []) self.assertEqual(non_match, make_typed_intervals([[0, 4]], NonMatch)) + + def test_avoid_empty_alleles_long_match(self): + """ + If we let the non-match interval be only [4,5], + this would result in an empty allele in the prg, + so require padding using the preceding match sequence + """ + msa = make_alignment(["TTAAGGTTT", "TTAA--TTT"]) + tester = IntervalPartitioner("TTAA**TTT", min_match_length=3, alignment=msa) + match, non_match, _ = tester.get_intervals() + self.assertEqual(match, make_typed_intervals([[0, 2], [6, 8]], Match)) + self.assertEqual(non_match, make_typed_intervals([[3, 5]], NonMatch)) + + def test_avoid_empty_alleles_short_match(self): + """ + Padding behaviour also expected, but now the leading match interval becomes too + short and collapses to a non_match interval + """ + msa = make_alignment(["TTAGGTTT", "TTA--TTT"]) + tester = IntervalPartitioner("TTA**TTT", min_match_length=3, alignment=msa) + match, non_match, _ = tester.get_intervals() + self.assertEqual(match, make_typed_intervals([[5, 7]], Match)) + self.assertEqual(non_match, make_typed_intervals([[0, 4]], NonMatch)) + + def test_avoid_empty_alleles_previous_non_match_merged(self): + """Edge case of collapsed match interval, part 2""" + msa = make_alignment(["CCTTAGGTTT", "AATTA--TTT"]) + tester = IntervalPartitioner("**TTA**TTT", min_match_length=3, alignment=msa) + match, non_match, _ = tester.get_intervals() + self.assertEqual(match, make_typed_intervals([[7, 9]], Match)) + self.assertEqual(non_match, make_typed_intervals([[0, 6]], NonMatch)) diff --git a/tests/test_seq_utils.py b/tests/test_seq_utils.py index 678f0c2..5804b7c 100644 --- a/tests/test_seq_utils.py +++ b/tests/test_seq_utils.py @@ -1,13 +1,14 @@ -import unittest +from unittest import TestCase from Bio import AlignIO from Bio.Seq import Seq from Bio.SeqRecord import SeqRecord -from make_prg.seq_utils import get_interval_seqs +from tests import make_alignment +from make_prg.seq_utils import get_interval_seqs, has_empty_sequence -class TestGetIntervals(unittest.TestCase): +class TestGetIntervals(TestCase): def test_ambiguous_bases_one_seq(self): alignment = AlignIO.MultipleSeqAlignment([SeqRecord(Seq("RWAAT"))]) result = get_interval_seqs(alignment) @@ -27,3 +28,9 @@ def test_first_sequence_in_is_first_sequence_out(self): result = get_interval_seqs(alignment) expected = ["TTTT", "AAAA", "CCC"] self.assertEqual(expected, result) + + +class TestEmptySeq(TestCase): + def test_sub_alignment_with_empty_sequence(self): + msa = make_alignment(["TTAGGTTT", "TTA--TTT", "GGA-TTTT"]) + self.assertTrue(has_empty_sequence(msa, [3, 4]))