Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use piper_phonemize as text tokenizer in ljspeech recipe #1511

Merged
merged 10 commits into from
Feb 29, 2024
64 changes: 9 additions & 55 deletions egs/ljspeech/TTS/local/prepare_token_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,88 +17,42 @@


"""
This file reads the texts in given manifest and generates the file that maps tokens to IDs.
This file generates the file that maps tokens to IDs.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The copyright should be 2023-2024.

"""

import argparse
import logging
from pathlib import Path
from typing import Dict

from lhotse import load_manifest
from piper_phonemize import get_espeak_map


def get_args():
parser = argparse.ArgumentParser()

parser.add_argument(
"--manifest-file",
type=Path,
default=Path("data/spectrogram/ljspeech_cuts_train.jsonl.gz"),
help="Path to the manifest file",
)

parser.add_argument(
"--tokens",
type=Path,
default=Path("data/tokens.txt"),
help="Path to the tokens",
help="Path to the dict that maps the text tokens to IDs",
)

return parser.parse_args()


def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file.

Note:
No need to implement `read_mapping` as it can be done
through :func:`k2.SymbolTable.from_file`.

Args:
filename:
Filename to save the mapping.
sym2id:
A dict mapping symbols to IDs.
Returns:
Return None.
"""
def get_token2id(filename: Path) -> Dict[str, int]:
"""Get a dict that maps token to IDs, and save it to the given filename."""
all_tokens = get_espeak_map()
with open(filename, "w", encoding="utf-8") as f:
for sym, i in sym2id.items():
f.write(f"{sym} {i}\n")


def get_token2id(manifest_file: Path) -> Dict[str, int]:
"""Return a dict that maps token to IDs."""
extra_tokens = [
"<blk>", # 0 for blank
"<sos/eos>", # 1 for sos and eos symbols.
"<unk>", # 2 for OOV
]
all_tokens = set()

cut_set = load_manifest(manifest_file)

for cut in cut_set:
# Each cut only contain one supervision
assert len(cut.supervisions) == 1, len(cut.supervisions)
for t in cut.tokens:
all_tokens.add(t)

all_tokens = extra_tokens + list(all_tokens)

token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
return token2id
for token, token_id in all_tokens.items():
f.write(f"{token} {token_id[0]}\n")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you sort by token_id in filename?

That is, sort the second column from 0 to vocab_size-1 in ascending order?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok.



if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

logging.basicConfig(format=formatter, level=logging.INFO)

args = get_args()
manifest_file = Path(args.manifest_file)
out_file = Path(args.tokens)

token2id = get_token2id(manifest_file)
write_mapping(out_file, token2id)
get_token2id(out_file)
9 changes: 6 additions & 3 deletions egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
import logging
from pathlib import Path

import g2p_en
import tacotron_cleaner.cleaners
from lhotse import CutSet, load_manifest
from piper_phonemize import phonemize_espeak


def prepare_tokens_ljspeech():
Expand All @@ -35,7 +35,6 @@ def prepare_tokens_ljspeech():
partition = "all"

cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
g2p = g2p_en.G2p()

new_cuts = []
for cut in cut_set:
Expand All @@ -45,7 +44,11 @@ def prepare_tokens_ljspeech():
# Text normalization
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
# Convert to phonemes
cut.tokens = g2p(text)
tokens_list = phonemize_espeak(text, "en-us")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At line 42

assert len(cut.supervisions) == 1, len(cut.supervisions)

Please use

assert len(cut.supervisions) == 1, (len(cut.supervisions), cut)

It is helpful to print the problematic cut on error.

tokens = []
for t in tokens_list:
tokens.extend(t)
cut.tokens = tokens
new_cuts.append(cut)

new_cut_set = CutSet.from_cuts(new_cuts)
Expand Down
16 changes: 10 additions & 6 deletions egs/ljspeech/TTS/prepare.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
cd vits/monotonic_align
python setup.py build_ext --inplace
cd ../../
else
else
log "monotonic_align lib already built"
fi
fi
Expand Down Expand Up @@ -80,6 +80,11 @@ fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare phoneme tokens for LJSpeech"
# We assume you have installed piper_phonemize and espnet_tts_frontend.
# If not, please install them with:
# - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize,
# could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
# - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then
./local/prepare_tokens_ljspeech.py
mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \
Expand Down Expand Up @@ -113,13 +118,12 @@ fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Generate token file"
# We assume you have installed g2p_en and espnet_tts_frontend.
# We assume you have installed piper_phonemize and espnet_tts_frontend.
# If not, please install them with:
# - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p
# - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize,
# could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
# - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/tokens.txt ]; then
./local/prepare_token_file.py \
--manifest-file data/spectrogram/ljspeech_cuts_train.jsonl.gz \
--tokens data/tokens.txt
./local/prepare_token_file.py --tokens data/tokens.txt
fi
fi
65 changes: 47 additions & 18 deletions egs/ljspeech/TTS/vits/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

from typing import Dict, List

import g2p_en
import tacotron_cleaner.cleaners
from piper_phonemize import phonemize_espeak
from utils import intersperse


Expand All @@ -38,21 +38,34 @@ def __init__(self, tokens: str):
id = int(info[0])
else:
token, id = info[0], int(info[1])
assert token not in self.token2id, token
self.token2id[token] = id

self.blank_id = self.token2id["<blk>"]
self.oov_id = self.token2id["<unk>"]
self.vocab_size = len(self.token2id)
# Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md
self.pad_id = self.token2id["_"] # padding
self.sos_id = self.token2id["^"] # beginning of an utterance (bos)
self.eos_id = self.token2id["$"] # end of an utterance (eos)
self.space_id = self.token2id[" "] # word separator (whitespace)

self.g2p = g2p_en.G2p()
self.vocab_size = len(self.token2id)

def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True):
def texts_to_token_ids(
self,
texts: List[str],
intersperse_blank: bool = True,
add_sos: bool = False,
add_eos: bool = False,
):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please give the return value a type hint.

"""
Args:
texts:
A list of transcripts.
intersperse_blank:
Whether to intersperse blanks in the token sequence.
add_sos:
Whether to add sos token at the start.
add_eos:
Whether to add eos token at the end.

Returns:
Return a list of token id list [utterance][token_id]
Expand All @@ -63,30 +76,44 @@ def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True):
# Text normalization
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
# Convert to phonemes
tokens = self.g2p(text)
tokens_list = phonemize_espeak(text, "en-us")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please pass en-us as an argument to this function.
You can use

lang: str = 'en-us`

as the last argument for this function

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks a lot!

tokens = []
for t in tokens_list:
tokens.extend(t)

token_ids = []
for t in tokens:
if t in self.token2id:
token_ids.append(self.token2id[t])
else:
token_ids.append(self.oov_id)
assert t in self.token2id, t
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert t in self.token2id, t
if t not in self.token2id:
logging.warning(f'Skip oov {t}')
continue

We just skip OOVs instead of throwing an assertion error, which
may kill the process.

token_ids.append(self.token2id[t])

if intersperse_blank:
token_ids = intersperse(token_ids, self.blank_id)
token_ids = intersperse(token_ids, self.pad_id)
if add_sos:
token_ids = [self.sos_id] + token_ids
if add_eos:
token_ids = token_ids + [self.eos_id]

token_ids_list.append(token_ids)

return token_ids_list

def tokens_to_token_ids(
self, tokens_list: List[str], intersperse_blank: bool = True
self,
tokens_list: List[str],
intersperse_blank: bool = True,
add_sos: bool = False,
add_eos: bool = False,
):
"""
Args:
tokens_list:
A list of token list, each corresponding to one utterance.
intersperse_blank:
Whether to intersperse blanks in the token sequence.
add_sos:
Whether to add sos token at the start.
add_eos:
Whether to add eos token at the end.

Returns:
Return a list of token id list [utterance][token_id]
Expand All @@ -96,13 +123,15 @@ def tokens_to_token_ids(
for tokens in tokens_list:
token_ids = []
for t in tokens:
if t in self.token2id:
token_ids.append(self.token2id[t])
else:
token_ids.append(self.oov_id)
assert t in self.token2id, t
token_ids.append(self.token2id[t])

if intersperse_blank:
token_ids = intersperse(token_ids, self.blank_id)
token_ids = intersperse(token_ids, self.pad_id)
if add_sos:
token_ids = [self.sos_id] + token_ids
if add_eos:
token_ids = token_ids + [self.eos_id]

token_ids_list.append(token_ids)

Expand Down
Loading