diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index 9e1b1c4097..b6e9ac8a14 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -241,7 +241,7 @@ def __init__( ) self.duration_predictor = DurationPredictor( - self.args.hidden_channels + self.embedded_speaker_dim, + self.args.hidden_channels, self.args.duration_predictor_hidden_channels, self.args.duration_predictor_kernel_size, self.args.duration_predictor_dropout_p, @@ -249,7 +249,7 @@ def __init__( if self.args.use_pitch: self.pitch_predictor = DurationPredictor( - self.args.hidden_channels + self.embedded_speaker_dim, + self.args.hidden_channels, self.args.pitch_predictor_hidden_channels, self.args.pitch_predictor_kernel_size, self.args.pitch_predictor_dropout_p, @@ -263,7 +263,7 @@ def __init__( if self.args.use_energy: self.energy_predictor = DurationPredictor( - self.args.hidden_channels + self.embedded_speaker_dim, + self.args.hidden_channels, self.args.energy_predictor_hidden_channels, self.args.energy_predictor_kernel_size, self.args.energy_predictor_dropout_p, @@ -299,7 +299,8 @@ def init_multispeaker(self, config: Coqpit): if config.use_d_vector_file: self.embedded_speaker_dim = config.d_vector_dim if self.args.d_vector_dim != self.args.hidden_channels: - self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1) + #self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1) + self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels) # init speaker embedding layer if config.use_speaker_embedding and not config.use_d_vector_file: print(" > Init speaker_embedding layer.") @@ -403,10 +404,13 @@ def _forward_encoder( # [B, T, C] x_emb = self.emb(x) # encoder pass - o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask) + #o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask) + o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask, g) # speaker conditioning # TODO: try different ways of conditioning - if g is not None: + if g is not None: + if hasattr(self, "proj_g"): + g = self.proj_g(g.view(g.shape[0], -1)).unsqueeze(-1) o_en = o_en + g return o_en, x_mask, g, x_emb