From 43c93d8a5578dadf4f56f21eb9cf0f0870e60fb7 Mon Sep 17 00:00:00 2001 From: Ryan Langman Date: Tue, 19 Sep 2023 10:56:27 -0700 Subject: [PATCH] [TTS] Fix audio codec type checks (#7373) * [TTS] Fix audio codec type checks Signed-off-by: Ryan * [TTS] Fix audio codec tests Signed-off-by: Ryan --------- Signed-off-by: Ryan --- .../tts/losses/audio_codec_loss.py | 6 +- nemo/collections/tts/models/audio_codec.py | 6 +- .../tts/modules/audio_codec_modules.py | 28 ++++---- .../tts/modules/encodec_modules.py | 64 +++++++++++-------- .../tts/modules/test_audio_codec_modules.py | 6 +- 5 files changed, 61 insertions(+), 49 deletions(-) diff --git a/nemo/collections/tts/losses/audio_codec_loss.py b/nemo/collections/tts/losses/audio_codec_loss.py index bde96fadb4c2..8819282f07bd 100644 --- a/nemo/collections/tts/losses/audio_codec_loss.py +++ b/nemo/collections/tts/losses/audio_codec_loss.py @@ -40,8 +40,8 @@ def __init__(self, loss_fn, loss_scale: float = 1.0): @property def input_types(self): return { - "target": NeuralType(('B', 'D', 'T'), RegressionValuesType()), "predicted": NeuralType(('B', 'D', 'T'), PredictionsType()), + "target": NeuralType(('B', 'D', 'T'), RegressionValuesType()), "target_len": NeuralType(tuple('B'), LengthsType()), } @@ -97,7 +97,7 @@ def input_types(self): @property def output_types(self): return { - "loss": [NeuralType(elements_type=LossType())], + "loss": NeuralType(elements_type=LossType()), } @typecheck() @@ -146,7 +146,7 @@ def input_types(self): @property def output_types(self): return { - "loss": [NeuralType(elements_type=LossType())], + "loss": NeuralType(elements_type=LossType()), } @typecheck() diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index 6414fa20e52d..63140b77f2b5 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -484,9 +484,11 @@ def configure_optimizers(self): sched_config = optim_config.pop("sched", None) OmegaConf.set_struct(optim_config, True) - gen_params = itertools.chain(self.audio_encoder.parameters(), self.audio_decoder.parameters()) - disc_params = self.discriminator.parameters() + vq_params = self.vector_quantizer.parameters() if self.vector_quantizer else [] + gen_params = itertools.chain(self.audio_encoder.parameters(), self.audio_decoder.parameters(), vq_params) optim_g = instantiate(optim_config, params=gen_params) + + disc_params = self.discriminator.parameters() optim_d = instantiate(optim_config, params=disc_params) if sched_config is None: diff --git a/nemo/collections/tts/modules/audio_codec_modules.py b/nemo/collections/tts/modules/audio_codec_modules.py index aaf4fb0a7f21..90c53b1f4337 100644 --- a/nemo/collections/tts/modules/audio_codec_modules.py +++ b/nemo/collections/tts/modules/audio_codec_modules.py @@ -12,15 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, Optional, Tuple +from typing import Optional, Tuple -import torch import torch.nn as nn -from einops import rearrange from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor +from nemo.core.classes.common import typecheck from nemo.core.classes.module import NeuralModule -from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, LengthsType, VoidType +from nemo.core.neural_types.elements import LengthsType, VoidType from nemo.core.neural_types.neural_type import NeuralType @@ -64,21 +63,22 @@ def __init__( def input_types(self): return { "inputs": NeuralType(('B', 'C', 'T'), VoidType()), - "lengths": NeuralType(tuple('B'), LengthsType()), + "input_len": NeuralType(tuple('B'), LengthsType()), } @property def output_types(self): return { - "out": [NeuralType(('B', 'C', 'T'), VoidType())], + "out": NeuralType(('B', 'C', 'T'), VoidType()), } def remove_weight_norm(self): nn.utils.remove_weight_norm(self.conv) - def forward(self, inputs, lengths): + @typecheck() + def forward(self, inputs, input_len): out = self.conv(inputs) - out = mask_sequence_tensor(out, lengths) + out = mask_sequence_tensor(out, input_len) return out @@ -101,21 +101,22 @@ def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride def input_types(self): return { "inputs": NeuralType(('B', 'C', 'T'), VoidType()), - "lengths": NeuralType(tuple('B'), LengthsType()), + "input_len": NeuralType(tuple('B'), LengthsType()), } @property def output_types(self): return { - "out": [NeuralType(('B', 'C', 'T'), VoidType())], + "out": NeuralType(('B', 'C', 'T'), VoidType()), } def remove_weight_norm(self): nn.utils.remove_weight_norm(self.conv) - def forward(self, inputs, lengths): + @typecheck() + def forward(self, inputs, input_len): out = self.conv(inputs) - out = mask_sequence_tensor(out, lengths) + out = mask_sequence_tensor(out, input_len) return out @@ -151,11 +152,12 @@ def input_types(self): @property def output_types(self): return { - "out": [NeuralType(('B', 'C', 'H', 'T'), VoidType())], + "out": NeuralType(('B', 'C', 'H', 'T'), VoidType()), } def remove_weight_norm(self): nn.utils.remove_weight_norm(self.conv) + @typecheck() def forward(self, inputs): return self.conv(inputs) diff --git a/nemo/collections/tts/modules/encodec_modules.py b/nemo/collections/tts/modules/encodec_modules.py index 031b2001e5ca..b05187ccb74b 100644 --- a/nemo/collections/tts/modules/encodec_modules.py +++ b/nemo/collections/tts/modules/encodec_modules.py @@ -72,13 +72,13 @@ def __init__(self, channels: int): def input_types(self): return { "inputs": NeuralType(('B', 'C', 'T_input'), VoidType()), - "lengths": NeuralType(tuple('B'), LengthsType()), + "input_len": NeuralType(tuple('B'), LengthsType()), } @property def output_types(self): return { - "out": [NeuralType(('B', 'C', 'T_out'), VoidType())], + "out": NeuralType(('B', 'C', 'T_out'), VoidType()), } def remove_weight_norm(self): @@ -86,14 +86,15 @@ def remove_weight_norm(self): self.res_conv1.remove_weight_norm() self.res_conv2.remove_weight_norm() - def forward(self, inputs, lengths): + @typecheck() + def forward(self, inputs, input_len): res = self.activation(inputs) - res = self.res_conv1(res, lengths) + res = self.res_conv1(inputs=res, input_len=input_len) res = self.activation(res) - res = self.res_conv2(res, lengths) + res = self.res_conv2(inputs=res, input_len=input_len) - out = self.pre_conv(inputs, lengths) + res - out = mask_sequence_tensor(out, lengths) + out = self.pre_conv(inputs=inputs, input_len=input_len) + res + out = mask_sequence_tensor(out, input_len) return out @@ -112,20 +113,21 @@ def __init__(self, dim: int, num_layers: int, rnn_type: str = "lstm", use_skip: def input_types(self): return { "inputs": NeuralType(('B', 'C', 'T'), VoidType()), - "lengths": NeuralType(tuple('B'), LengthsType()), + "input_len": NeuralType(tuple('B'), LengthsType()), } @property def output_types(self): return { - "out": [NeuralType(('B', 'C', 'T'), VoidType())], + "out": NeuralType(('B', 'C', 'T'), VoidType()), } - def forward(self, inputs, lengths): + @typecheck() + def forward(self, inputs, input_len): inputs = rearrange(inputs, "B C T -> B T C") packed_inputs = nn.utils.rnn.pack_padded_sequence( - inputs, lengths=lengths.cpu(), batch_first=True, enforce_sorted=False + inputs, lengths=input_len.cpu(), batch_first=True, enforce_sorted=False ) packed_out, _ = self.rnn(packed_inputs) out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True) @@ -183,15 +185,15 @@ def __init__( @property def input_types(self): return { - "audio": NeuralType(('B', 'C', 'T_audio'), AudioSignal()), + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), "audio_len": NeuralType(tuple('B'), LengthsType()), } @property def output_types(self): return { - "encoded": [NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation())], - "encoded_len": [NeuralType(tuple('B'), LengthsType())], + "encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()), + "encoded_len": NeuralType(tuple('B'), LengthsType()), } def remove_weight_norm(self): @@ -201,26 +203,27 @@ def remove_weight_norm(self): for down_sample_conv in self.down_sample_conv_layers: down_sample_conv.remove_weight_norm() + @typecheck() def forward(self, audio, audio_len): encoded_len = audio_len audio = rearrange(audio, "B T -> B 1 T") # [B, C, T_audio] - out = self.pre_conv(audio, encoded_len) + out = self.pre_conv(inputs=audio, input_len=encoded_len) for res_block, down_sample_conv, down_sample_rate in zip( self.res_blocks, self.down_sample_conv_layers, self.down_sample_rates ): # [B, C, T] - out = res_block(out, encoded_len) + out = res_block(inputs=out, input_len=encoded_len) out = self.activation(out) encoded_len = encoded_len // down_sample_rate # [B, 2 * C, T / down_sample_rate] - out = down_sample_conv(out, encoded_len) + out = down_sample_conv(inputs=out, input_len=encoded_len) - out = self.rnn(out, encoded_len) + out = self.rnn(inputs=out, input_len=encoded_len) out = self.activation(out) # [B, encoded_dim, T_encoded] - encoded = self.post_conv(out, encoded_len) + encoded = self.post_conv(inputs=out, input_len=encoded_len) return encoded, encoded_len @@ -274,7 +277,7 @@ def input_types(self): @property def output_types(self): return { - "audio": NeuralType(('B', 'C', 'T_audio'), AudioSignal()), + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), "audio_len": NeuralType(tuple('B'), LengthsType()), } @@ -285,23 +288,24 @@ def remove_weight_norm(self): for res_block in self.res_blocks: res_block.remove_weight_norm() + @typecheck() def forward(self, inputs, input_len): audio_len = input_len # [B, C, T_encoded] - out = self.pre_conv(inputs, audio_len) - out = self.rnn(out, audio_len) + out = self.pre_conv(inputs=inputs, input_len=audio_len) + out = self.rnn(inputs=out, input_len=audio_len) for res_block, up_sample_conv, up_sample_rate in zip( self.res_blocks, self.up_sample_conv_layers, self.up_sample_rates ): audio_len = audio_len * up_sample_rate out = self.activation(out) # [B, C / 2, T * up_sample_rate] - out = up_sample_conv(out, audio_len) - out = res_block(out, audio_len) + out = up_sample_conv(inputs=out, input_len=audio_len) + out = res_block(inputs=out, input_len=audio_len) out = self.activation(out) # [B, 1, T_audio] - out = self.post_conv(out, audio_len) + out = self.post_conv(inputs=out, input_len=audio_len) audio = self.out_activation(out) audio = rearrange(audio, "B 1 T -> B T") return audio, audio_len @@ -356,6 +360,7 @@ def output_types(self): "fmap": [NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())], } + @typecheck() def forward(self, audio): fmap = [] @@ -363,11 +368,11 @@ def forward(self, audio): out = self.stft(audio) for conv in self.conv_layers: # [batch, filters, T_spec, fft // 2**i] - out = conv(out) + out = conv(inputs=out) out = self.activation(out) fmap.append(out) # [batch, 1, T_spec, fft // 8] - scores = self.conv_post(out) + scores = self.conv_post(inputs=out) fmap.append(scores) scores = rearrange(scores, "B 1 T C -> B C T") @@ -382,7 +387,7 @@ def __init__(self, resolutions): @property def input_types(self): return { - "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_real": NeuralType(('B', 'T_audio'), AudioSignal()), "audio_gen": NeuralType(('B', 'T_audio'), AudioSignal()), } @@ -395,6 +400,7 @@ def output_types(self): "fmaps_gen": [[NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())]], } + @typecheck() def forward(self, audio_real, audio_gen): scores_real = [] scores_gen = [] @@ -627,6 +633,7 @@ def output_types(self): "indices": NeuralType(('B', 'T'), Index()), } + @typecheck() def forward(self, inputs, input_len): input_flat = rearrange(inputs, "B T D -> (B T) D") self._init_codes(input_flat) @@ -746,6 +753,7 @@ def output_types(self): "commit_loss": NeuralType((), LossType()), } + @typecheck() def forward(self, inputs: Tensor, input_len: Tensor) -> Tuple[Tensor, Tensor, float]: commit_loss = 0.0 residual = rearrange(inputs, "B D T -> B T D") diff --git a/tests/collections/tts/modules/test_audio_codec_modules.py b/tests/collections/tts/modules/test_audio_codec_modules.py index 948b1220f39c..4650a6508edd 100644 --- a/tests/collections/tts/modules/test_audio_codec_modules.py +++ b/tests/collections/tts/modules/test_audio_codec_modules.py @@ -40,7 +40,7 @@ def test_conv1d(self): lengths = torch.tensor([self.len1, self.len2], dtype=torch.int32) conv = Conv1dNorm(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=self.kernel_size) - out = conv(inputs, lengths) + out = conv(inputs=inputs, input_len=lengths) assert out.shape == (self.batch_size, self.out_channels, self.max_len) assert torch.all(out[0, :, : self.len1] != 0.0) @@ -66,7 +66,7 @@ def test_conv1d_downsample(self): stride=stride, padding=padding, ) - out = conv(inputs, lengths) + out = conv(inputs=inputs, input_len=lengths) assert out.shape == (self.batch_size, self.out_channels, out_len) assert torch.all(out[0, :, :out_len_1] != 0.0) @@ -87,7 +87,7 @@ def test_conv1d_transpose_upsample(self): conv = ConvTranspose1dNorm( in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=self.kernel_size, stride=stride ) - out = conv(inputs, lengths) + out = conv(inputs=inputs, input_len=lengths) assert out.shape == (self.batch_size, self.out_channels, out_len) assert torch.all(out[0, :, :out_len_1] != 0.0)