Skip to content

Commit

Permalink
[TTS] Add tutorials for FastPitch TTS speaker adaptation with adapters (
Browse files Browse the repository at this point in the history
#6431)

* Add tts adapter tutorial

Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>

* Update main tutorial

Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>

* Add tts adapter tutorial

Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>

* Update main tutorial

Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>

* Update tutorial

Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>

* Follow comments

Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Follow comments

Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>

* Fix .nemo loading error

Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Support multi-speaker fine-tuning

Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Follow comments

Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>

* Use .nemo

Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>

* Follow Comments

Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>

* Fix bug

Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>

* Fix bug

Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>

* Fix bug

Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>

* Add precomputed speaker emb

Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix space

Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>

* Remove repeated argument

Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* optional batch size

Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>

* Fix comments in notebook

Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>

---------

Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and yaoyu-33 committed May 26, 2023
1 parent 50a28fb commit 4d3d58a
Show file tree
Hide file tree
Showing 8 changed files with 1,613 additions and 14 deletions.
11 changes: 6 additions & 5 deletions examples/tts/conf/fastpitch_align_44100_adapter.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ model:
dropatt: 0.1
dropemb: 0.0
d_embed: ${model.symbols_embedding_dim}
condition_types: [ "add", "layernorm" ] # options: [ "add", "cat", "layernorm" ]
condition_types: [ "add", "layernorm" ] # options: [ "add", "concat", "layernorm" ]

output_fft:
_target_: nemo.collections.tts.modules.transformer.FFTransformerDecoder
Expand All @@ -221,12 +221,12 @@ model:
dropout: 0.1
dropatt: 0.1
dropemb: 0.0
condition_types: [ "add", "layernorm" ] # options: [ "add", "cat", "layernorm" ]
condition_types: [ "add", "layernorm" ] # options: [ "add", "concat", "layernorm" ]

alignment_module:
_target_: nemo.collections.tts.modules.aligner.AlignmentEncoder
n_text_channels: ${model.symbols_embedding_dim}
condition_types: [ "add" ] # options: [ "add", "cat" ]
condition_types: [ "add" ] # options: [ "add", "concat" ]

duration_predictor:
_target_: nemo.collections.tts.modules.fastpitch.TemporalPredictor
Expand All @@ -235,7 +235,7 @@ model:
filter_size: 256
dropout: 0.1
n_layers: 2
condition_types: [ "add", "layernorm" ] # options: [ "add", "cat", "layernorm" ]
condition_types: [ "add", "layernorm" ] # options: [ "add", "concat", "layernorm" ]

pitch_predictor:
_target_: nemo.collections.tts.modules.fastpitch.TemporalPredictor
Expand All @@ -244,10 +244,11 @@ model:
filter_size: 256
dropout: 0.1
n_layers: 2
condition_types: [ "add", "layernorm" ] # options: [ "add", "cat", "layernorm" ]
condition_types: [ "add", "layernorm" ] # options: [ "add", "concat", "layernorm" ]

speaker_encoder:
_target_: nemo.collections.tts.modules.submodules.SpeakerEncoder
precomputed_embedding_dim: null
lookup_module:
_target_: nemo.collections.tts.modules.submodules.SpeakerLookupTable
n_speakers: ???
Expand Down
3 changes: 3 additions & 0 deletions nemo/collections/tts/models/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
if self.fastpitch.speaker_emb is not None:
self.export_config["num_speakers"] = cfg.n_speakers

# Adapter modules setup (from FastPitchAdapterModelMixin)
self.setup_adapters()

def _get_default_text_tokenizer_conf(self):
text_tokenizer: TextTokenizerConfig = TextTokenizerConfig()
return OmegaConf.create(OmegaConf.to_yaml(text_tokenizer))
Expand Down
16 changes: 10 additions & 6 deletions nemo/collections/tts/modules/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@ def __init__(
self.learn_alignment = aligner is not None
self.use_duration_predictor = True
self.binarize = False

# TODO: combine self.speaker_emb with self.speaker_encoder
# cfg: remove `n_speakers`, create `speaker_encoder.lookup_module`
# state_dict: move `speaker_emb.weight` to `speaker_encoder.lookup_module.table.weight`
Expand Down Expand Up @@ -244,10 +243,10 @@ def output_types(self):
"energy_tgt": NeuralType(('B', 'T_audio'), RegressionValuesType()),
}

def get_speaker_embedding(self, speaker, reference_spec, reference_spec_lens):
def get_speaker_embedding(self, batch_size, speaker, reference_spec, reference_spec_lens):
"""spk_emb: Bx1xD"""
if self.speaker_encoder is not None:
spk_emb = self.speaker_encoder(speaker, reference_spec, reference_spec_lens).unsqueeze(1)
spk_emb = self.speaker_encoder(batch_size, speaker, reference_spec, reference_spec_lens).unsqueeze(1)
elif self.speaker_emb is not None:
if speaker is None:
raise ValueError('Please give speaker id to get lookup speaker embedding.')
Expand Down Expand Up @@ -281,7 +280,10 @@ def forward(

# Calculate speaker embedding
spk_emb = self.get_speaker_embedding(
speaker=speaker, reference_spec=reference_spec, reference_spec_lens=reference_spec_lens,
batch_size=text.shape[0],
speaker=speaker,
reference_spec=reference_spec,
reference_spec_lens=reference_spec_lens,
)

# Input FFT
Expand Down Expand Up @@ -379,10 +381,12 @@ def infer(
reference_spec=None,
reference_spec_lens=None,
):

# Calculate speaker embedding
spk_emb = self.get_speaker_embedding(
speaker=speaker, reference_spec=reference_spec, reference_spec_lens=reference_spec_lens,
batch_size=text.shape[0],
speaker=speaker,
reference_spec=reference_spec,
reference_spec_lens=reference_spec_lens,
)

# Input FFT
Expand Down
22 changes: 20 additions & 2 deletions nemo/collections/tts/modules/submodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,18 +709,29 @@ class SpeakerEncoder represents speakers representation.
This module can combine GST (global style token) based speaker embeddings and lookup table speaker embeddings.
"""

def __init__(self, lookup_module=None, gst_module=None, precomputed_embedding_dim=None):
    """Combine lookup-table and GST (global style token) based speaker representations.

    Args:
        lookup_module: Torch module producing lookup-table speaker embeddings from speaker ids.
        gst_module: Neural module producing GST-based speaker embeddings from reference audio.
        precomputed_embedding_dim: If given, register a trainable placeholder parameter of
            this dimension so that a precomputed speaker embedding can be used; it is
            expected to be filled in later (e.g. via ``overwrite_precomputed_emb`` or a
            checkpoint load).
    """
    super(SpeakerEncoder, self).__init__()

    # Multi-speaker embedding (lookup by speaker id)
    self.lookup_module = lookup_module

    # Reference speaker embedding (derived from reference spectrogram)
    self.gst_module = gst_module

    if precomputed_embedding_dim is not None:
        # Use zeros rather than torch.empty(): an uninitialized Parameter holds
        # nondeterministic garbage values if it is ever consumed before being
        # overwritten or restored from a checkpoint.
        self.precomputed_emb = torch.nn.Parameter(torch.zeros(precomputed_embedding_dim))
    else:
        self.precomputed_emb = None

@property
def input_types(self):
return {
"batch_size": NeuralType(optional=True),
"speaker": NeuralType(('B'), Index(), optional=True),
"reference_spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType(), optional=True),
"reference_spec_lens": NeuralType(('B'), LengthsType(), optional=True),
Expand All @@ -732,9 +743,16 @@ def output_types(self):
"embs": NeuralType(('B', 'D'), EncodedRepresentation()),
}

def forward(self, speaker=None, reference_spec=None, reference_spec_lens=None):
def overwrite_precomputed_emb(self, emb):
    """Store ``emb`` as the precomputed speaker embedding, wrapped as a trainable Parameter."""
    new_param = torch.nn.Parameter(emb)
    self.precomputed_emb = new_param

def forward(self, batch_size=None, speaker=None, reference_spec=None, reference_spec_lens=None):
embs = None

# Get Precomputed speaker embedding
if self.precomputed_emb is not None:
return self.precomputed_emb.unsqueeze(0).repeat(batch_size, 1)

# Get Lookup table speaker embedding
if self.lookup_module is not None and speaker is not None:
embs = self.lookup_module(speaker)
Expand Down
2 changes: 1 addition & 1 deletion scripts/dataset_processing/tts/extract_sup_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_pitch_stats(pitch_list):
def preprocess_ds_for_fastpitch_align(dataloader):
pitch_list = []
for batch in tqdm(dataloader, total=len(dataloader)):
audios, audio_lengths, tokens, tokens_lengths, align_prior_matrices, pitches, pitches_lengths = batch
audios, audio_lengths, tokens, tokens_lengths, align_prior_matrices, pitches, pitches_lengths, *_ = batch
pitch = pitches.squeeze(0)
pitch_list.append(pitch[pitch != 0])

Expand Down
11 changes: 11 additions & 0 deletions scripts/dataset_processing/tts/resynthesize_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,15 @@ def resynthesize_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
batch = to_device_recursive(batch, self.device)

mels, mel_lens = self.model.preprocessor(input_signal=batch["audio"], length=batch["audio_lens"])

reference_audio = batch.get("reference_audio", None)
reference_audio_len = batch.get("reference_audio_lens", None)
reference_spec, reference_spec_len = None, None
if reference_audio is not None:
reference_spec, reference_spec_len = self.model.preprocessor(
input_signal=reference_audio, length=reference_audio_len
)

outputs_tuple = self.model.forward(
text=batch["text"],
durs=None,
Expand All @@ -127,6 +136,8 @@ def resynthesize_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
attn_prior=batch.get("attn_prior"),
mel_lens=mel_lens,
input_lens=batch["text_lens"],
reference_spec=reference_spec,
reference_spec_lens=reference_spec_len,
)
names = self.model.fastpitch.output_types.keys()
return {"spec": mels, "mel_lens": mel_lens, **dict(zip(names, outputs_tuple))}
Expand Down
Loading

0 comments on commit 4d3d58a

Please sign in to comment.