Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TTS] Add tutorials for FastPitch TTS speaker adaptation with adapters #6431

Merged
merged 35 commits into from
May 3, 2023
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
a2c1708
Add tts adapter tutorial
hsiehjackson Apr 14, 2023
9692b7d
Update main tutorial
hsiehjackson Apr 18, 2023
d7567f0
Add tts adapter tutorial
hsiehjackson Apr 14, 2023
76571f7
Update main tutorial
hsiehjackson Apr 18, 2023
3361eab
Merge branch 'tts_fastpitch_adapter_tutorial' of https://github.com/N…
hsiehjackson Apr 18, 2023
c33c188
Update tutorial
hsiehjackson Apr 18, 2023
e0d7f1f
Merge branch 'main' into tts_fastpitch_adapter_tutorial
hsiehjackson Apr 19, 2023
a57e12a
Follow comments
hsiehjackson Apr 24, 2023
2ad34f8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 24, 2023
49f0f02
Follow comments
hsiehjackson Apr 24, 2023
ccc58bd
Merge branch 'tts_fastpitch_adapter_tutorial' of https://github.com/N…
hsiehjackson Apr 24, 2023
0631e5e
Fix load .nemo error
hsiehjackson Apr 26, 2023
d3e20d9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 26, 2023
ea85082
Support multi-speaker fine-tune
hsiehjackson Apr 26, 2023
55bd14b
Merge branch 'tts_fastpitch_adapter_tutorial' of https://github.com/N…
hsiehjackson Apr 26, 2023
7f7fa26
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 26, 2023
6dc0ff4
Follow comments
hsiehjackson Apr 27, 2023
94afa02
Merge branch 'tts_fastpitch_adapter_tutorial' of https://github.com/N…
hsiehjackson Apr 27, 2023
e50a603
Use .nemo
hsiehjackson Apr 27, 2023
f43a9f1
Merge branch 'main' into tts_fastpitch_adapter_tutorial
hsiehjackson Apr 27, 2023
b58c677
Follow Comments
hsiehjackson Apr 28, 2023
1ae1250
Fix bug
hsiehjackson Apr 28, 2023
0c2cbc0
Fix bug
hsiehjackson Apr 28, 2023
e71387e
Fix bug
hsiehjackson Apr 28, 2023
60555ab
Add precomputed speaker emb
hsiehjackson May 2, 2023
e9ba7f6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 2, 2023
4795fe1
Fix space
hsiehjackson May 2, 2023
87a430e
Fix space
hsiehjackson May 2, 2023
4b375d9
Remove repeated argument
hsiehjackson May 2, 2023
908aa67
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 2, 2023
f402847
Merge branch 'main' into tts_fastpitch_adapter_tutorial
hsiehjackson May 2, 2023
4e75a6e
optional batch size
hsiehjackson May 2, 2023
1cb4727
Merge branch 'tts_fastpitch_adapter_tutorial' of https://github.com/N…
hsiehjackson May 2, 2023
3053a52
Fix comments in notebook
hsiehjackson May 2, 2023
9037e0c
Merge branch 'main' into tts_fastpitch_adapter_tutorial
hsiehjackson May 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions examples/tts/conf/fastpitch_align_44100_adapter.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ model:
dropatt: 0.1
dropemb: 0.0
d_embed: ${model.symbols_embedding_dim}
condition_types: [ "add", "layernorm" ] # options: [ "add", "cat", "layernorm" ]
condition_types: [ "add", "layernorm" ] # options: [ "add", "concat", "layernorm" ]

output_fft:
_target_: nemo.collections.tts.modules.transformer.FFTransformerDecoder
Expand All @@ -221,12 +221,12 @@ model:
dropout: 0.1
dropatt: 0.1
dropemb: 0.0
condition_types: [ "add", "layernorm" ] # options: [ "add", "cat", "layernorm" ]
condition_types: [ "add", "layernorm" ] # options: [ "add", "concat", "layernorm" ]

alignment_module:
_target_: nemo.collections.tts.modules.aligner.AlignmentEncoder
n_text_channels: ${model.symbols_embedding_dim}
condition_types: [ "add" ] # options: [ "add", "cat" ]
condition_types: [ "add" ] # options: [ "add", "concat" ]

duration_predictor:
_target_: nemo.collections.tts.modules.fastpitch.TemporalPredictor
Expand All @@ -235,7 +235,7 @@ model:
filter_size: 256
dropout: 0.1
n_layers: 2
condition_types: [ "add", "layernorm" ] # options: [ "add", "cat", "layernorm" ]
condition_types: [ "add", "layernorm" ] # options: [ "add", "concat", "layernorm" ]

pitch_predictor:
_target_: nemo.collections.tts.modules.fastpitch.TemporalPredictor
Expand All @@ -244,10 +244,11 @@ model:
filter_size: 256
dropout: 0.1
n_layers: 2
condition_types: [ "add", "layernorm" ] # options: [ "add", "cat", "layernorm" ]
condition_types: [ "add", "layernorm" ] # options: [ "add", "concat", "layernorm" ]

speaker_encoder:
_target_: nemo.collections.tts.modules.submodules.SpeakerEncoder
precomputed_embedding_dim: null
lookup_module:
_target_: nemo.collections.tts.modules.submodules.SpeakerLookupTable
n_speakers: ???
Expand Down
3 changes: 3 additions & 0 deletions nemo/collections/tts/models/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
if self.fastpitch.speaker_emb is not None:
self.export_config["num_speakers"] = cfg.n_speakers

# Adapter modules setup (from FastPitchAdapterModelMixin)
self.setup_adapters()

def _get_default_text_tokenizer_conf(self):
text_tokenizer: TextTokenizerConfig = TextTokenizerConfig()
return OmegaConf.create(OmegaConf.to_yaml(text_tokenizer))
Expand Down
16 changes: 10 additions & 6 deletions nemo/collections/tts/modules/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@ def __init__(
self.learn_alignment = aligner is not None
self.use_duration_predictor = True
self.binarize = False

# TODO: combine self.speaker_emb with self.speaker_encoder
# cfg: remove `n_speakers`, create `speaker_encoder.lookup_module`
# state_dict: move `speaker_emb.weight` to `speaker_encoder.lookup_module.table.weight`
Expand Down Expand Up @@ -244,10 +243,10 @@ def output_types(self):
"energy_tgt": NeuralType(('B', 'T_audio'), RegressionValuesType()),
}

def get_speaker_embedding(self, speaker, reference_spec, reference_spec_lens):
def get_speaker_embedding(self, batch_size, speaker, reference_spec, reference_spec_lens):
"""spk_emb: Bx1xD"""
if self.speaker_encoder is not None:
spk_emb = self.speaker_encoder(speaker, reference_spec, reference_spec_lens).unsqueeze(1)
spk_emb = self.speaker_encoder(batch_size, speaker, reference_spec, reference_spec_lens).unsqueeze(1)
elif self.speaker_emb is not None:
if speaker is None:
raise ValueError('Please give speaker id to get lookup speaker embedding.')
Expand Down Expand Up @@ -281,7 +280,10 @@ def forward(

# Calculate speaker embedding
spk_emb = self.get_speaker_embedding(
speaker=speaker, reference_spec=reference_spec, reference_spec_lens=reference_spec_lens,
batch_size=text.shape[0],
speaker=speaker,
reference_spec=reference_spec,
reference_spec_lens=reference_spec_lens,
)

# Input FFT
Expand Down Expand Up @@ -379,10 +381,12 @@ def infer(
reference_spec=None,
reference_spec_lens=None,
):

# Calculate speaker embedding
spk_emb = self.get_speaker_embedding(
speaker=speaker, reference_spec=reference_spec, reference_spec_lens=reference_spec_lens,
batch_size=text.shape[0],
speaker=speaker,
reference_spec=reference_spec,
reference_spec_lens=reference_spec_lens,
)

# Input FFT
Expand Down
22 changes: 20 additions & 2 deletions nemo/collections/tts/modules/submodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,18 +709,29 @@ class SpeakerEncoder represents speakers representation.
This module can combine GST (global style token) based speaker embeddings and lookup table speaker embeddings.
"""

def __init__(self, lookup_module=None, gst_module=None):
def __init__(self, lookup_module=None, gst_module=None, precomputed_embedding_dim=None):
"""
lookup_module: Torch module to get lookup based speaker embedding
gst_module: Neural module to get GST based speaker embedding
precomputed_embedding_dim: Give precomputed speaker embedding dimension to use precompute speaker embedding
"""
super(SpeakerEncoder, self).__init__()

# Multi-speaker embedding
self.lookup_module = lookup_module

# Reference speaker embedding
self.gst_module = gst_module

if precomputed_embedding_dim is not None:
self.precomputed_emb = torch.nn.Parameter(torch.empty(precomputed_embedding_dim))
else:
self.precomputed_emb = None

@property
def input_types(self):
return {
"batch_size": NeuralType(),
"speaker": NeuralType(('B'), Index(), optional=True),
"reference_spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType(), optional=True),
"reference_spec_lens": NeuralType(('B'), LengthsType(), optional=True),
Expand All @@ -732,9 +743,16 @@ def output_types(self):
"embs": NeuralType(('B', 'D'), EncodedRepresentation()),
}

def forward(self, speaker=None, reference_spec=None, reference_spec_lens=None):
def overwrite_precomputed_emb(self, emb):
self.precomputed_emb = torch.nn.Parameter(emb)

def forward(self, batch_size, speaker=None, reference_spec=None, reference_spec_lens=None):
hsiehjackson marked this conversation as resolved.
Show resolved Hide resolved
embs = None

# Get Precomputed speaker embedding
if self.precomputed_emb is not None:
return self.precomputed_emb.unsqueeze(0).repeat(batch_size, 1)

# Get Lookup table speaker embedding
if self.lookup_module is not None and speaker is not None:
embs = self.lookup_module(speaker)
Expand Down
2 changes: 1 addition & 1 deletion scripts/dataset_processing/tts/extract_sup_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_pitch_stats(pitch_list):
def preprocess_ds_for_fastpitch_align(dataloader):
pitch_list = []
for batch in tqdm(dataloader, total=len(dataloader)):
audios, audio_lengths, tokens, tokens_lengths, align_prior_matrices, pitches, pitches_lengths = batch
audios, audio_lengths, tokens, tokens_lengths, align_prior_matrices, pitches, pitches_lengths, *_ = batch
pitch = pitches.squeeze(0)
pitch_list.append(pitch[pitch != 0])

Expand Down
11 changes: 11 additions & 0 deletions scripts/dataset_processing/tts/resynthesize_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,15 @@ def resynthesize_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
batch = to_device_recursive(batch, self.device)

mels, mel_lens = self.model.preprocessor(input_signal=batch["audio"], length=batch["audio_lens"])

reference_audio = batch.get("reference_audio", None)
reference_audio_len = batch.get("reference_audio_lens", None)
rlangman marked this conversation as resolved.
Show resolved Hide resolved
reference_spec, reference_spec_len = None, None
if reference_audio is not None:
reference_spec, reference_spec_len = self.model.preprocessor(
input_signal=reference_audio, length=reference_audio_len
)

outputs_tuple = self.model.forward(
text=batch["text"],
durs=None,
Expand All @@ -127,6 +136,8 @@ def resynthesize_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
attn_prior=batch.get("attn_prior"),
mel_lens=mel_lens,
input_lens=batch["text_lens"],
reference_spec=reference_spec,
reference_spec_lens=reference_spec_len,
)
names = self.model.fastpitch.output_types.keys()
return {"spec": mels, "mel_lens": mel_lens, **dict(zip(names, outputs_tuple))}
Expand Down
Loading