diff --git a/TTS/tts/utils/text/punctuation.py b/TTS/tts/utils/text/punctuation.py index 8d199cc545..36c467d083 100644 --- a/TTS/tts/utils/text/punctuation.py +++ b/TTS/tts/utils/text/punctuation.py @@ -15,7 +15,6 @@ class PuncPosition(Enum): BEGIN = 0 END = 1 MIDDLE = 2 - ALONE = 3 class Punctuation: @@ -92,7 +91,7 @@ def _strip_to_restore(self, text): return [text], [] # the text is only punctuations if len(matches) == 1 and matches[0].group() == text: - return [], [_PUNC_IDX(text, PuncPosition.ALONE)] + return [], [_PUNC_IDX(text, PuncPosition.BEGIN)] # build a punctuation map to be used later to restore punctuations puncs = [] for match in matches: @@ -107,11 +106,14 @@ def _strip_to_restore(self, text): for idx, punc in enumerate(puncs): split = text.split(punc.punc) prefix, suffix = split[0], punc.punc.join(split[1:]) + text = suffix + if prefix == "": + # We don't want to insert an empty string in case of initial punctuation + continue splitted_text.append(prefix) # if the text does not end with a punctuation, add it to the last item if idx == len(puncs) - 1 and len(suffix) > 0: splitted_text.append(suffix) - text = suffix return splitted_text, puncs @classmethod @@ -127,10 +129,10 @@ def restore(cls, text, puncs): ['This is', 'example'], ['.', '!'] -> "This is. example!" """ - return cls._restore(text, puncs, 0) + return cls._restore(text, puncs) @classmethod - def _restore(cls, text, puncs, num): # pylint: disable=too-many-return-statements + def _restore(cls, text, puncs): # pylint: disable=too-many-return-statements """Auxiliary method for Punctuation.restore()""" if not puncs: return text @@ -142,21 +144,18 @@ def _restore(cls, text, puncs, num): # pylint: disable=too-many-return-statemen current = puncs[0] if current.position == PuncPosition.BEGIN: - return cls._restore([current.punc + text[0]] + text[1:], puncs[1:], num) + return cls._restore([current.punc + text[0]] + text[1:], puncs[1:]) if current.position == PuncPosition.END: - return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:], num + 1) - - if current.position == PuncPosition.ALONE: - return [current.mark] + cls._restore(text, puncs[1:], num + 1) + return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:]) # POSITION == MIDDLE if len(text) == 1: # pragma: nocover # a corner case where the final part of an intermediate # mark (I) has not been phonemized - return cls._restore([text[0] + current.punc], puncs[1:], num) + return cls._restore([text[0] + current.punc], puncs[1:]) - return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:], num) + return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:]) # if __name__ == "__main__": diff --git a/tests/text_tests/test_punctuation.py b/tests/text_tests/test_punctuation.py index 141c10e48f..bb7b11edce 100644 --- a/tests/text_tests/test_punctuation.py +++ b/tests/text_tests/test_punctuation.py @@ -11,6 +11,11 @@ def setUp(self): ("This, is my text ... to be striped !! from text", "This is my text to be striped from text"), ("This, is my text ... to be striped from text?", "This is my text to be striped from text"), ("This, is my text to be striped from text", "This is my text to be striped from text"), + (".", ""), + (" . ", ""), + ("!!! Attention !!!", "Attention"), + ("!!! Attention !!! This is just a ... test.", "Attention This is just a test"), + ("!!! Attention! This is just a ... test.", "Attention This is just a test"), ] def test_get_set_puncs(self):