
Commit a0f65c6

engiecat authored and r9y9 committed
Custom Dataset support + Gentle-based custom dataset preprocessing support (#78)
* Fixed TypeError (torch.index_select received an invalid combination of arguments) raised during synthesis:

      File "synthesis.py", line 137, in <module>
        model, text, p=replace_pronunciation_prob, speaker_id=speaker_id, fast=True)
      File "synthesis.py", line 66, in tts
        sequence, text_positions=text_positions, speaker_ids=speaker_ids)
      File "H:\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 325, in __call__
        result = self.forward(*input, **kwargs)
      File "H:\Tensorflow_Study\git\deepvoice3_pytorch\deepvoice3_pytorch\__init__.py", line 79, in forward
        text_positions, frame_positions, input_lengths)
      File "H:\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 325, in __call__
        result = self.forward(*input, **kwargs)
      File "H:\Tensorflow_Study\git\deepvoice3_pytorch\deepvoice3_pytorch\__init__.py", line 116, in forward
        text_sequences, lengths=input_lengths, speaker_embed=speaker_embed)
      File "H:\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 325, in __call__
        result = self.forward(*input, **kwargs)
      File "H:\Tensorflow_Study\git\deepvoice3_pytorch\deepvoice3_pytorch\deepvoice3.py", line 75, in forward
        x = self.embed_tokens(text_sequences)  <- change this to long!
      File "H:\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 325, in __call__
        result = self.forward(*input, **kwargs)
      File "H:\envs\pytorch\lib\site-packages\torch\nn\modules\sparse.py", line 103, in forward
        self.scale_grad_by_freq, self.sparse
      File "H:\envs\pytorch\lib\site-packages\torch\nn\_functions\thnn\sparse.py", line 59, in forward
        output = torch.index_select(weight, 0, indices.view(-1))
      TypeError: torch.index_select received an invalid combination of arguments - got (torch.cuda.FloatTensor, int, torch.cuda.IntTensor), but expected (torch.cuda.FloatTensor source, int dim, torch.cuda.LongTensor index)

  Changed text_sequences to long, as required by torch.index_select.
* Fixed NoneType error in collect_features
* requirements.txt fix
* Memory leakage bugfix + hparams change
* Pre-PR modifications
* Pre-PR modifications 2
* Pre-PR modifications 3
* Post-PR modification
* Remove requirements.txt
* num_workers to 1 in train.py
* Windows log filename bugfix
* Revert "Windows log filename bugfix" (this reverts commit 5214c24)
* merge 2
* Windows filename bugfix (in Windows, the previous naming causes WinError 123)
* Cleanup before PR
* JSON format metadata support: supports JSON format for dataset creation; ensures compatibility with http://github.com/carpedm20/multi-Speaker-tacotron-tensorflow
* Web-based Gentle aligner support
* README change + Gentle patch
* .gitignore change
* Flake8 fix
* Post-PR commit - also fixed the #5 / #53 (comment) issue, solved in PyTorch 0.4
* Post-PR 2 - .gitignore
1 parent 8716bb5 commit a0f65c6
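The first fix above comes down to casting the token-index tensor to `long` before the embedding lookup, since `torch.index_select` (used inside `nn.Embedding`) requires a LongTensor index. A minimal sketch of the idea (illustrative tensors, not the repository's actual code):

```python
import torch
import torch.nn as nn

embed_tokens = nn.Embedding(num_embeddings=256, embedding_dim=128)

# An int32 index tensor reproduces the failure mode described above on older
# PyTorch versions; the embedding lookup expects int64 (Long) indices.
text_sequences = torch.IntTensor([[12, 5, 42, 7]])

# The fix applied in this commit: cast indices to long before embedding.
x = embed_tokens(text_sequences.long())
print(x.shape)  # torch.Size([1, 4, 128])
```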

8 files changed (+478, -8 lines)

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,8 @@ log
generated
data
text
+datasets
+testout

# Created by https://www.gitignore.io

README.md

Lines changed: 48 additions & 5 deletions
@@ -23,8 +23,8 @@ A notebook supposed to be executed on https://colab.research.google.com is avail
- Convolutional sequence-to-sequence model with attention for text-to-speech synthesis
- Multi-speaker and single speaker versions of DeepVoice3
- Audio samples and pre-trained models
-- Preprocessor for [LJSpeech (en)](https://keithito.com/LJ-Speech-Dataset/), [JSUT (jp)](https://sites.google.com/site/shinnosuketakamichi/publication/jsut) and [VCTK](http://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html) datasets
-- Language-dependent frontend text processor for English and Japanese
+- Preprocessor for [LJSpeech (en)](https://keithito.com/LJ-Speech-Dataset/), [JSUT (jp)](https://sites.google.com/site/shinnosuketakamichi/publication/jsut) and [VCTK](http://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html) datasets, as well as [carpedm20/multi-speaker-tacotron-tensorflow](https://github.com/carpedm20/multi-Speaker-tacotron-tensorflow) compatible custom datasets (in JSON format)
+- Language-dependent frontend text processor for English and Japanese

### Samples

@@ -104,7 +104,7 @@ python train.py --preset=presets/deepvoice3_ljspeech.json --data-root=./data/ljs
- LJSpeech (en): https://keithito.com/LJ-Speech-Dataset/
- VCTK (en): http://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html
- JSUT (jp): https://sites.google.com/site/shinnosuketakamichi/publication/jsut
-- NIKL (ko): http://www.korean.go.kr/front/board/boardStandardView.do?board_id=4&mn_id=17&b_seq=464
+- NIKL (ko) (**a Korean cellphone number is needed to access it**): http://www.korean.go.kr/front/board/boardStandardView.do?board_id=4&mn_id=17&b_seq=464

### 1. Preprocessing

@@ -130,6 +130,47 @@ python preprocess.py --preset=presets/deepvoice3_ljspeech.json ljspeech ~/data/L

When this is done, you will see extracted features (mel-spectrograms and linear spectrograms) in `./data/ljspeech`.

#### 1-1. Building a custom dataset (using json_meta)
Building your own dataset, with metadata in JSON format (compatible with [carpedm20/multi-speaker-tacotron-tensorflow](https://github.com/carpedm20/multi-Speaker-tacotron-tensorflow)), is currently supported.
Usage:

```
python preprocess.py json_meta ${list-of-JSON-metadata-paths} ${out_dir} --preset=<json>
```
You may need to modify the pre-existing preset JSON file, especially `n_speakers`. For English multi-speaker training, start with `presets/deepvoice3_vctk.json`.

Assuming you have dataset A (speaker A) and dataset B (speaker B), described in the JSON metadata files `./datasets/datasetA/alignment.json` and `./datasets/datasetB/alignment.json` respectively, you can preprocess the data by:

```
python preprocess.py json_meta "./datasets/datasetA/alignment.json,./datasets/datasetB/alignment.json" "./datasets/processed_A+B" --preset=(path to preset json file)
```
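For readers building such metadata from scratch, the sketch below assembles an `alignment.json`-style file from paired wav/txt files. The exact schema is defined by the carpedm20 repository and by this project's json_meta preprocessor; the simple path-to-transcript mapping used here is an assumption for illustration.

```python
import json
from glob import glob
from os.path import splitext

def build_alignment_json(wav_glob, out_path):
    # Hypothetical helper: collect {audio_path: transcript} pairs, which is the
    # general shape of alignment.json files (field layout assumed, not taken
    # from this repository).
    meta = {}
    for wav_path in sorted(glob(wav_glob)):
        txt_path = splitext(wav_path)[0] + ".txt"
        with open(txt_path, encoding="utf-8") as f:
            meta[wav_path] = f.read().strip()
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)

build_alignment_json("./datasets/datasetA/wavs/*.wav",
                     "./datasets/datasetA/alignment.json")
```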
#### 1-2. Preprocessing custom English datasets with long silence (based on [vctk_preprocess](vctk_preprocess/))

Some datasets, especially automatically generated ones, may include long silences and undesirable leading/trailing noises that undermine the char-level seq2seq model
(e.g. VCTK, although this is covered in vctk_preprocess).

To deal with the problem, `gentle_web_align.py` will
- **Prepare phoneme alignments for all utterances**
- Cut silences during preprocessing

`gentle_web_align.py` uses [Gentle](https://github.com/lowerquality/gentle), a Kaldi-based speech-text alignment tool. It accesses a web-served Gentle application, aligns the given sound segments with their transcripts, and converts the result to HTK-style label files to be processed in `preprocess.py`. Gentle can be run on Linux/Mac/Windows (via Docker).

Preliminary results show that while the HTK/festival/merlin-based method in `vctk_preprocess/prepare_vctk_labels.py` works better on VCTK, Gentle is more stable on audio clips with ambient noise (e.g. movie excerpts).

Usage:
(Assuming Gentle is running at `localhost:8567`, the default when not specified)
1. When sound files and transcript files are saved in separate folders (e.g. sound files are at `datasetA/wavs` and transcripts are at `datasetA/txts`):
```
python gentle_web_align.py -w "datasetA/wavs/*.wav" -t "datasetA/txts/*.txt" --server_addr=localhost --port=8567
```

2. When sound files and transcript files are saved in a nested structure (e.g. `datasetB/speakerN/blahblah.wav` and `datasetB/speakerN/blahblah.txt`):
```
python gentle_web_align.py --nested-directories="datasetB" --server_addr=localhost --port=8567
```
**Once you have phoneme alignments for each utterance, you can extract features by running `preprocess.py`.**
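Before kicking off a long alignment run, it can help to confirm that the Gentle web server is actually reachable. A minimal, purely illustrative check (the health probe itself is an assumption; only the host/port convention comes from the usage above):

```python
import requests

server_addr, port = "localhost", 8567

# Illustrative reachability check (not part of the repository): any HTTP
# response from the Gentle web app means the server is up.
try:
    resp = requests.get("http://{}:{}/".format(server_addr, port), timeout=5)
    print("Gentle server reachable, status", resp.status_code)
except requests.exceptions.ConnectionError:
    print("Gentle server not reachable - start it (e.g. via Docker) before aligning.")
```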
### 2. Training

Usage:
@@ -141,7 +182,7 @@ python train.py --data-root=${data-root} --preset=<json> --hparams="parameters y
Suppose you build a DeepVoice3-style model using the LJSpeech dataset; then you can train your model by:

```
-python train.py --preset=presets/deepvoice3_ljspeech.json --data-root=./data/ljspeech/
+python train.py --preset=presets/deepvoice3_ljspeech.json --data-root=./data/ljspeech/
```

Model checkpoints (.pth) and alignments (.png) are saved in the `./checkpoints` directory every 10000 steps by default.
@@ -249,7 +290,9 @@ From my experience, it can get reasonable speech quality very quickly rather tha
There are two important options used above:

- `--restore-parts=<N>`: It specifies where to load model parameters from. The differences from the option `--checkpoint=<N>` are: 1) `--restore-parts=<N>` ignores all invalid parameters, while `--checkpoint=<N>` doesn't (see the sketch after this list); 2) `--restore-parts=<N>` tells the trainer to start from step 0, while `--checkpoint=<N>` tells the trainer to continue from the last step. `--checkpoint=<N>` should be fine if you are using exactly the same model and continuing training, but `--restore-parts=<N>` is useful if you want to customize your model architecture and take advantage of a pre-trained model.
-- `--speaker-id=<N>`: It specifies which speaker's data is used for training. This should only be specified if you are using a multi-speaker dataset. As for VCTK, speaker ids are automatically assigned incrementally (0, 1, ..., 107) according to the `speaker_info.txt` in the dataset.
+- `--speaker-id=<N>`: It specifies which speaker's data is used for training. This should only be specified if you are using a multi-speaker dataset. As for VCTK, speaker ids are automatically assigned incrementally (0, 1, ..., 107) according to the `speaker_info.txt` in the dataset.
+
+If you are training a multi-speaker model, speaker adaptation will only work **when `n_speakers` is identical**.
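Conceptually, the "ignores all invalid parameters" behaviour of `--restore-parts` amounts to loading only the checkpoint tensors whose names and shapes still match the model. A minimal sketch of that idea in plain PyTorch (illustrative helper, not the repository's actual implementation; it assumes the checkpoint file holds a bare state dict):

```python
import torch

def restore_parts(model, checkpoint_path):
    """Load only the parameters that still fit the (possibly modified) model."""
    state = torch.load(checkpoint_path, map_location="cpu")
    own = model.state_dict()
    # Keep entries whose name exists in the new model and whose shape matches.
    usable = {k: v for k, v in state.items()
              if k in own and own[k].shape == v.shape}
    own.update(usable)
    model.load_state_dict(own)
    print("Restored {}/{} tensors".format(len(usable), len(own)))
```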
## Acknowledgements

gentle_web_align.py

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
# -*- coding: utf-8 -*-
"""
Created on Sat Apr 21 09:06:37 2018
Phoneme alignment and conversion in HTK-style label file using Web-served Gentle
This works on any type of english dataset.
Unlike prepare_htk_alignments_vctk.py, this is Python3 and Windows(with Docker) compatible.
Preliminary results show that gentle has better performance with noisy dataset
(e.g. movie extracted audioclips)
*This work was derived from vctk_preprocess/prepare_htk_alignments_vctk.py
@author: engiecat(github)

usage:
    gentle_web_align.py (-w wav_pattern) (-t text_pattern) [options]
    gentle_web_align.py (--nested-directories=<main_directory>) [options]

options:
    -w <wav_pattern> --wav_pattern=<wav_pattern>  Pattern of wav files to be aligned
    -t <txt_pattern> --txt_pattern=<txt_pattern>  Pattern of txt transcript files to be aligned (same name required)
    --nested-directories=<main_directory>         Process every wav/txt file in the subfolders of the given folder
    --server_addr=<server_addr>                   Server address that serves gentle. [default: localhost]
    --port=<port>                                 Server port that serves gentle. [default: 8567]
    --max_unalign=<max_unalign>                   Maximum threshold for unalignment occurence (0.0 ~ 1.0) [default: 0.3]
    --skip-already-done                           Skips if there are preexisting .lab file
    -h --help                                     show this help message and exit
"""

from docopt import docopt
from glob import glob
from tqdm import tqdm
import os.path
import requests
import numpy as np


def write_hts_label(labels, lab_path):
    lab = ""
    for s, e, l in labels:
        s, e = float(s) * 1e7, float(e) * 1e7
        s, e = int(s), int(e)
        lab += "{} {} {}\n".format(s, e, l)
    print(lab)
    with open(lab_path, "w", encoding='utf-8') as f:
        f.write(lab)


def json2hts(data):
    emit_bos = False
    emit_eos = False

    phone_start = 0
    phone_end = None
    labels = []
    failure_count = 0

    for word in data["words"]:
        case = word["case"]
        if case != "success":
            failure_count += 1  # instead of failing everything,
            # raise RuntimeError("Alignment failed")
            continue
        start = float(word["start"])
        word_end = float(word["end"])

        if not emit_bos:
            labels.append((phone_start, start, "silB"))
            emit_bos = True

        phone_start = start
        phone_end = None
        for phone in word["phones"]:
            ph = str(phone["phone"][:-2])
            duration = float(phone["duration"])
            phone_end = phone_start + duration
            labels.append((phone_start, phone_end, ph))
            phone_start += duration
        assert np.allclose(phone_end, word_end)
    if not emit_eos:
        labels.append((phone_start, phone_end, "silE"))
        emit_eos = True
    unalign_ratio = float(failure_count) / len(data['words'])
    return unalign_ratio, labels


def gentle_request(wav_path, txt_path, server_addr, port, debug=False):
    print('\n')
    response = None
    wav_name = os.path.basename(wav_path)
    txt_name = os.path.basename(txt_path)
    if os.path.splitext(wav_name)[0] != os.path.splitext(txt_name)[0]:
        print(' [!] wav name and transcript name does not match - exiting...')
        return response
    with open(txt_path, 'r', encoding='utf-8-sig') as txt_file:
        print('Transcript - ' + ''.join(txt_file.readlines()))
    with open(wav_path, 'rb') as wav_file, open(txt_path, 'rb') as txt_file:
        params = (('async', 'false'),)
        files = {'audio': (wav_name, wav_file),
                 'transcript': (txt_name, txt_file),
                 }
        server_path = 'http://' + server_addr + ':' + str(port) + '/transcriptions'
        response = requests.post(server_path, params=params, files=files)
        if response.status_code != 200:
            print(' [!] External server({}) returned bad response({})'.format(server_path, response.status_code))
    if debug:
        print('Response')
        print(response.json())
    return response


if __name__ == '__main__':
    arguments = docopt(__doc__)
    server_addr = arguments['--server_addr']
    port = int(arguments['--port'])
    max_unalign = float(arguments['--max_unalign'])
    if arguments['--nested-directories'] is None:
        wav_paths = sorted(glob(arguments['--wav_pattern']))
        txt_paths = sorted(glob(arguments['--txt_pattern']))
    else:
        # if this is multi-foldered environment
        # (e.g. DATASET/speaker1/blahblah.wav)
        wav_paths = []
        txt_paths = []
        topdir = arguments['--nested-directories']
        subdirs = [f for f in os.listdir(topdir) if os.path.isdir(os.path.join(topdir, f))]
        for subdir in subdirs:
            wav_pattern_subdir = os.path.join(topdir, subdir, '*.wav')
            txt_pattern_subdir = os.path.join(topdir, subdir, '*.txt')
            wav_paths.extend(sorted(glob(wav_pattern_subdir)))
            txt_paths.extend(sorted(glob(txt_pattern_subdir)))

    t = tqdm(range(len(wav_paths)))
    for idx in t:
        try:
            t.set_description("Align via Gentle")
            wav_path = wav_paths[idx]
            txt_path = txt_paths[idx]
            lab_path = os.path.splitext(wav_path)[0] + '.lab'
            if os.path.exists(lab_path) and arguments['--skip-already-done']:
                print('[!] skipping because of pre-existing .lab file - {}'.format(lab_path))
                continue
            res = gentle_request(wav_path, txt_path, server_addr, port)
            unalign_ratio, lab = json2hts(res.json())
            print('[*] Unaligned Ratio - {}'.format(unalign_ratio))
            if unalign_ratio > max_unalign:
                print('[!] skipping this due to bad alignment')
                continue
            write_hts_label(lab, lab_path)
        except:
            # if sth happens, skip it
            import traceback
            tb = traceback.format_exc()
            print('[!] ERROR while processing {}'.format(wav_paths[idx]))
            print('[!] StackTrace - ')
            print(tb)
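For reference, `write_hts_label` above writes one `start end label` triple per line, with times in HTK 100 ns units and `silB`/`silE` marking leading/trailing silence. A small sketch of reading such a `.lab` file back into seconds (the sample values are invented for illustration):

```python
# Parse an HTK-style .lab file as written by write_hts_label above.
# The sample content is made up; real files come from the Gentle alignment.
sample_lab = """0 1500000 silB
1500000 2400000 hh
2400000 3100000 ah
3100000 3800000 l
3800000 4600000 ow
4600000 6000000 silE
"""

def read_hts_label(text):
    labels = []
    for line in text.strip().splitlines():
        start, end, label = line.split()
        # HTK times are in 100 ns units; convert back to seconds.
        labels.append((int(start) / 1e7, int(end) / 1e7, label))
    return labels

for start_s, end_s, label in read_hts_label(sample_lab):
    print("{:.3f}-{:.3f}s {}".format(start_s, end_s, label))
```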

hparams.py

Lines changed: 8 additions & 0 deletions
@@ -125,6 +125,14 @@
    # Forced garbage collection probability
    # Use only when MemoryError continues in Windows (Disabled by default)
    #gc_probability = 0.001,
+
+    # json_meta mode only
+    # 0: "use all",
+    # 1: "ignore only unmatched_alignment",
+    # 2: "fully ignore recognition",
+    ignore_recognition_level = 2,
+    min_text=20,  # when dealing with non-dedicated speech datasets (e.g. movie excerpts), setting min_text above 15 is desirable. Can be adjusted per dataset.
+    process_only_htk_aligned = False,  # if true, data without a phoneme alignment file (.lab) will be ignored
)
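The new `ignore_recognition_level` and `min_text` entries act as data filters during json_meta preprocessing. The sketch below is a hypothetical reading of the documented levels (0/1/2), not the repository's actual preprocessing code; the boolean flags are assumptions made for illustration:

```python
# Hypothetical filter illustrating the hparams documented above; the real
# json_meta preprocessor in this repository may differ in detail.
def keep_utterance(text, is_aligned, recognized_only, hp):
    """Decide whether to keep an utterance given the json_meta hparams.

    is_aligned: alignment metadata matched this clip (assumed flag).
    recognized_only: the transcript came from speech recognition only (assumed flag).
    """
    if len(text) < hp["min_text"]:
        return False  # very short transcripts hurt the seq2seq model
    level = hp["ignore_recognition_level"]
    if level == 0:          # "use all"
        return True
    if level == 1:          # "ignore only unmatched_alignment"
        return is_aligned
    return is_aligned and not recognized_only  # 2: "fully ignore recognition"

hp = {"min_text": 20, "ignore_recognition_level": 2}
print(keep_utterance("short", True, False, hp))                               # False (too short)
print(keep_utterance("a reasonably long transcript here", True, False, hp))   # True
```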
