Commit: Remove espeak

dnhkng committed Dec 10, 2024
1 parent 99fbf5e commit 559a4c4
Showing 8 changed files with 322 additions and 31 deletions.
16 changes: 9 additions & 7 deletions demo.ipynb
@@ -23,8 +23,11 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# Instantiate the TTS engine\n",
-"glados_tts = tts.TTSEngine()"
+"glados_tts = tts.Synthesizer(\n",
+"    model_path=str(\"models/glados.onnx\"),\n",
+"    use_cuda=True,\n",
+"    speaker_id=0,\n",
+"    )"
 ]
 },
 {
@@ -34,11 +37,10 @@
 "outputs": [],
 "source": [
 "# Generate the audio.\n",
-"# Glados is spelt incorrectly on purpose to make the pronunciation more accurate.\n",
-"audio = glados_tts.generate_speech_audio(\"Hello, my name is Gladohs. I am an AI created by Aperture Science.\")\n",
+"audio = glados_tts.generate_speech_audio(\"Hello, my name is Glados. I am an AI created by Aperture Science.\")\n",
 "\n",
 "# Play the audio\n",
-"sd.play(audio, tts.RATE)"
+"sd.play(audio, glados_tts.rate)"
 ]
 },
 {
@@ -88,7 +90,7 @@
 ],
 "metadata": {
 "kernelspec": {
-"display_name": "GLaDOS",
+"display_name": "glados",
 "language": "python",
 "name": "python3"
 },
@@ -102,7 +104,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.11.6"
+"version": "3.10.15"
 }
 },
 "nbformat": 4,
295 changes: 295 additions & 0 deletions glados/phonemizer.py
@@ -0,0 +1,295 @@
import re
from enum import Enum
from functools import cache
from itertools import zip_longest
from pathlib import Path
from pickle import load
from typing import Dict, Iterable, List, Union

import numpy as np
import onnxruntime

class ModelConfig:
    MODEL_NAME: str = 'models/phomenizer_en.onnx'
    PHONEME_DICT_PATH: Path = Path('./models/lang_phoneme_dict.pkl')
    TOKEN_TO_IDX_PATH: Path = Path('./models/token_to_idx.pkl')
    IDX_TO_TOKEN_PATH: Path = Path('./models/idx_to_token.pkl')
    CHAR_REPEATS: int = 3
    MODEL_INPUT_LENGTH: int = 64
    EXPAND_ACRONYMS: bool = True
    USE_CUDA: bool = True


class SpecialTokens(Enum):
    PAD = '_'
    START = '<start>'
    END = '<end>'
    EN_US = '<en_us>'


class Punctuation(Enum):
    PUNCTUATION = '().,:?!/–'
    HYPHEN = '-'
    SPACE = ' '

    @classmethod
    @cache
    def get_punc_set(cls) -> set[str]:
        return set(cls.PUNCTUATION.value + cls.HYPHEN.value + cls.SPACE.value)

    @classmethod
    @cache
    def get_punc_pattern(cls) -> re.Pattern:
        return re.compile(f'([{cls.PUNCTUATION.value + cls.SPACE.value}])')


class Phonemizer:
    def __init__(self) -> None:
        """Load the pickled lookup dictionaries and create the ONNX inference session."""
        self.phoneme_dict = self._load_pickle(ModelConfig.PHONEME_DICT_PATH)
        self.token_to_idx = self._load_pickle(ModelConfig.TOKEN_TO_IDX_PATH)
        self.idx_to_token = self._load_pickle(ModelConfig.IDX_TO_TOKEN_PATH)

        providers = ["CPUExecutionProvider"]
        if ModelConfig.USE_CUDA:
            providers = [
                ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "HEURISTIC"}),
                "CPUExecutionProvider",
            ]

        self.ort_session = onnxruntime.InferenceSession(
            ModelConfig.MODEL_NAME,
            sess_options=onnxruntime.SessionOptions(),
            providers=providers,
        )

        self.special_tokens: set[str] = {SpecialTokens.PAD.value, SpecialTokens.END.value, SpecialTokens.EN_US.value}

    @staticmethod
    def _load_pickle(path: Path) -> dict:
        """Load a pickled dictionary from path."""
        with path.open('rb') as f:
            return load(f)

    @staticmethod
    def unique_consecutive(arr: np.ndarray) -> List[np.ndarray]:
        """
        Equivalent to torch.unique_consecutive for numpy arrays.
        :param arr: Array to process.
        :return: List of rows with consecutive duplicates removed.
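        Example (per row): [1, 1, 2, 2, 3] -> [1, 2, 3].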
"""
result = []
for row in arr:
if len(row) == 0:
result.append(row)
else:
mask = np.concatenate(([True], row[1:] != row[:-1]))
result.append(row[mask])
return result

@staticmethod
def remove_padding(arr: np.ndarray, padding_value: int = 0) -> List[np.ndarray]:
return [row[row != padding_value] for row in arr]

@staticmethod
def trim_to_stop(arr: np.ndarray, end_index: int = 2) -> List[np.ndarray]:
"""
Trims the array to the stop index.
:param arr: Array to trim.
:param end_index: Index to stop at.
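        :return: List of rows, each truncated just past the first end_index.
        Example (end_index=2): [5, 7, 2, 9] -> [5, 7, 2] (the stop token is kept).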
"""
result = []
for row in arr:
stop_index = np.where(row == end_index)[0]
if len(stop_index) > 0:
result.append(row[:stop_index[0]+1])
else:
result.append(row)
return result

def process_model_output(self, arr: np.ndarray) -> List[np.ndarray]:
"""
Processes the output of the model to get the phoneme indices.
:param arr: Model output.
:return: List of phoneme indices.
"""

arr = np.argmax(arr[0], axis=2)
arr = self.unique_consecutive(arr)
arr = self.remove_padding(arr)
arr = self.trim_to_stop(arr)
return arr

@staticmethod
def expand_acronym(word: str) -> str:
"""
Expands an acronym into its subwords.
:param word: Acronym to expand.
:return: Expanded acronym.
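        Example: "TTS" -> "T-T-S"; "WiFi" -> "Wi-Fi".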
"""
subwords = []
for subword in word.split(Punctuation.HYPHEN.value):
expanded = []
for a, b in zip_longest(subword, subword[1:]):
expanded.append(a)
if b is not None and b.isupper():
expanded.append(Punctuation.HYPHEN.value)
expanded = ''.join(expanded)
subwords.append(expanded)
return Punctuation.HYPHEN.value.join(subwords)

def encode(self, sentence: Iterable[str]) -> List[int]:
"""
Maps a sequence of symbols for a language to a sequence of indices.
:param sentence: Sentence (or word) as a sequence of symbols.
:return: Sequence of token indices.
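        Note: each character is repeated CHAR_REPEATS times before lookup; this repeated-character form is what gets fed to the ONNX model.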
"""
sentence = [item for item in sentence for _ in range(ModelConfig.CHAR_REPEATS)]
sentence = [s.lower() for s in sentence]
sequence = [self.token_to_idx[c] for c in sentence if c in self.token_to_idx]
return [self.token_to_idx[SpecialTokens.START.value]] + sequence + [self.token_to_idx[SpecialTokens.END.value]]

def decode(self, sequence: np.ndarray) -> str:
"""
Maps a sequence of indices to an array of symbols.
:param sequence: Encoded sequence to be decoded.
:return: Decoded sequence of symbols.
"""
decoded = [self.idx_to_token[int(t)] for t in sequence if int(t) in self.idx_to_token]
return ''.join(d for d in decoded if d not in self.special_tokens)

@staticmethod
def pad_sequence_fixed(v: List[np.ndarray], target_length: int = ModelConfig.MODEL_INPUT_LENGTH) -> np.ndarray:
"""
Pad or truncate a list of arrays to a fixed length.
:param v: List of arrays.
:param target_length: Target length to pad or truncate to.
:return: Padded array.
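        Example (target_length=4): [[1, 2, 3], [4, 5]] -> [[1, 2, 3, 0], [4, 5, 0, 0]].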
"""
result = np.zeros((len(v), target_length), dtype=np.int64)

for i, seq in enumerate(v):
length = min(len(seq), target_length) # Handle both shorter and longer sequences
result[i, :length] = seq[:length] # Copy either the full sequence or its truncated version

return result


def _get_dict_entry(self, word: str, lang: str, punc_set: set[str]) -> str | None:
"""
Gets the phoneme entry for a word in the dictionary.
:param word: Word to get phoneme entry for.
:param lang: Language of the word.
:param punc_set: Set of punctuation characters.
:return: Phoneme entry for the word.
"""
if word in punc_set or len(word) == 0:
return word
if not self.phoneme_dict or lang not in self.phoneme_dict:
return None
phoneme_dict = self.phoneme_dict[lang]
if word in phoneme_dict:
return phoneme_dict[word]
elif word.lower() in phoneme_dict:
return phoneme_dict[word.lower()]
elif word.title() in phoneme_dict:
return phoneme_dict[word.title()]
else:
return None

@staticmethod
def _get_phonemes(word: str,
word_phonemes: Dict[str, Union[str, None]],
word_splits: Dict[str, List[str]]) -> str:
"""
Gets the phonemes for a word. If the word is not in the phoneme dictionary, it is split into subwords.
:param word: Word to get phonemes for.
:param word_phonemes: Dictionary of word phonemes.
"""
phons = word_phonemes[word]
if phons is None:
subwords = word_splits[word]
subphons = [word_phonemes[w] for w in subwords]
phons = ''.join(subphons)
return phons

def _clean_and_split_texts(self, texts: List[str], punc_set: set[str], punc_pattern: re.Pattern) -> tuple[List[List[str]], set[str]]:
split_text, cleaned_words = [], set()
for text in texts:
cleaned_text = ''.join(t for t in text if t.isalnum() or t in punc_set)
split = [s for s in re.split(punc_pattern, cleaned_text) if len(s) > 0]
split_text.append(split)
cleaned_words.update(split)
return split_text, cleaned_words

def convert_to_phonemes(self, texts: List[str], lang: str) -> List[str]:
"""
Converts a list of texts to phonemes using a phonemizer.
:param texts: List of texts to convert.
:param lang: Language of the texts.
:return: List of phonemes.
"""
split_text, cleaned_words = [], set()
punc_set = Punctuation.get_punc_set()
punc_pattern = Punctuation.get_punc_pattern()

# Step 1: Preprocess texts
split_text, cleaned_words = self._clean_and_split_texts(texts, punc_set, punc_pattern)

# Step 2: Collect dictionary phonemes for words and hyphenated words
for punct in punc_set:
self.phoneme_dict[punct] = punct
word_phonemes = {word: self.phoneme_dict.get(word) for word in cleaned_words}


# Step 3: If word is not in dictionary, split it into subwords
words_to_split = [w for w in cleaned_words if word_phonemes[w] is None]

word_splits = {
key: re.split(r'([-])', self.expand_acronym(word) if ModelConfig.EXPAND_ACRONYMS else word)
for key, word in zip(words_to_split, words_to_split)
}

subwords = {w for values in word_splits.values() for w in values if w not in word_phonemes}

for subword in subwords:
word_phonemes[subword] = self._get_dict_entry(word=subword, lang=lang, punc_set=punc_set)

# Step 4: Predict all subwords that are missing in the phoneme dict
words_to_predict = [word for word, phons in word_phonemes.items()
if phons is None and len(word_splits.get(word, [])) <= 1]
input_batch = [self.encode(word) for word in words_to_predict]
input_batch = self.pad_sequence_fixed(input_batch)


ort_inputs = {self.ort_session.get_inputs()[0].name: input_batch}
ort_outs = self.ort_session.run(None, ort_inputs)

ids = self.process_model_output(ort_outs)

# Step 5: Add predictions to the dictionary
for id, word in zip(ids, words_to_predict):
word_phonemes[word] = self.decode(id)

# Step 6: Get phonemes for each word in the text
phoneme_lists = []
for text in split_text:
text_phons = [
self._get_phonemes(word=word, word_phonemes=word_phonemes,
word_splits=word_splits)
for word in text
]
phoneme_lists.append(text_phons)

return [''.join(phoneme_list) for phoneme_list in phoneme_lists]
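
For reference, a minimal usage sketch of the new Phonemizer (not part of the commit; the "en_us" language key and example sentence are assumptions, and the model files configured in ModelConfig must exist under models/):

    from glados.phonemizer import Phonemizer

    phonemizer = Phonemizer()  # loads the ONNX model and pickled dictionaries
    # Convert a batch of sentences to phoneme strings, one per input text.
    phonemes = phonemizer.convert_to_phonemes(["Hello, Aperture Science."], lang="en_us")
    print(phonemes)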
