
Commit: Modify residual transformer

ezerhouni committed Mar 4, 2024
1 parent bbb2f35 commit 7daac3d
Showing 2 changed files with 8 additions and 12 deletions.
6 changes: 2 additions & 4 deletions egs/ljspeech/TTS/vits2/duration_discriminator.py
@@ -95,11 +95,11 @@ def forward_probability(
         x = self.pre_out_conv_1(x * x_mask)
         x = torch.relu(x)
         x = self.pre_out_norm_1(x)
-        x = self.dropout(x)
+
         x = self.pre_out_conv_2(x * x_mask)
         x = torch.relu(x)
         x = self.pre_out_norm_2(x)
-        x = self.dropout(x)
+
         x = x * x_mask
         x = x.transpose(1, 2)
         output_prob = self.output_layer(x)
@@ -123,12 +123,10 @@ def forward(
         x = self.conv_1(x * x_mask)
         x = torch.relu(x)
         x = self.norm_1(x)
-        x = self.dropout(x)
 
         x = self.conv_2(x * x_mask)
         x = torch.relu(x)
         x = self.norm_2(x)
-        x = self.dropout(x)
 
         output_probs = []
         for dur in [dur_r, dur_hat]:
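For reference, a minimal standalone sketch of the masked conv -> ReLU -> norm pattern these blocks use (shapes and the plain nn.LayerNorm stand-in are illustrative assumptions, not the recipe's own modules):

import torch
import torch.nn as nn

# Illustrative sizes; the recipe's duration discriminator defines its own.
B, C, T = 4, 192, 50
x = torch.randn(B, C, T)
x_mask = torch.ones(B, 1, T)  # 1 for real frames, 0 for padded frames

conv = nn.Conv1d(C, C, kernel_size=3, padding=1)
norm = nn.LayerNorm(C)

# Mask, convolve, activate, then normalise; nn.LayerNorm expects the
# channel dimension last, hence the transposes around it.
x = conv(x * x_mask)
x = torch.relu(x)
x = norm(x.transpose(1, 2)).transpose(1, 2)
print(x.shape)  # torch.Size([4, 192, 50])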
14 changes: 6 additions & 8 deletions egs/ljspeech/TTS/vits2/residual_coupling.py
@@ -295,7 +295,7 @@ def __init__(
         self.use_only_mean = use_only_mean
 
         self.pre_transformer = Transformer(
-            self.half_channels,
+            hidden_channels,
             num_heads=heads_transformer,
             num_layers=layers_transformer,
             cnn_module_kernel=kernel_size_transformer,
@@ -362,14 +362,12 @@ def forward(
             Tensor: Log-determinant tensor for NLL (B,) if not inverse.
         """
         xa, xb = x.split(x.size(1) // 2, dim=1)
+        h_mask = make_pad_mask(torch.sum(x_mask, dim=[1, 2]).type(torch.int64))
 
-        x_trans_mask = make_pad_mask(torch.sum(x_mask, dim=[1, 2]).type(torch.int64))
-        xa_ = self.pre_transformer(
-            (xa * x_mask).transpose(1, 2), x_trans_mask
-        ).transpose(1, 2)
-        xa_ = xa + xa_
-
-        h = self.input_conv(xa_) * x_mask
+        h = self.input_conv(xa) * x_mask
+        h = h + self.pre_transformer((h * x_mask).transpose(1, 2), h_mask).transpose(
+            1, 2
+        )  # vits2 residual connection
         h = self.encoder(h, x_mask, g=g)
 
         stats = self.proj(h) * x_mask
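For context, a minimal sketch of the ordering the new code computes (module names follow the diff; the toy TransformerEncoder stand-in, channel sizes, and the omitted padding mask are assumptions for illustration). The pre-transformer now runs on h, the output of input_conv, and is added back as a residual, which is why it is built with hidden_channels rather than self.half_channels:

import torch
import torch.nn as nn

half_channels, hidden_channels, T = 96, 192, 50

# Stand-ins for the recipe's own modules.
input_conv = nn.Conv1d(half_channels, hidden_channels, 1)
pre_transformer = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=hidden_channels, nhead=2, batch_first=True),
    num_layers=1,
)

xa = torch.randn(4, half_channels, T)  # first half of the coupling input
x_mask = torch.ones(4, 1, T)           # all-ones mask: no padding in this toy

# New ordering: project to hidden_channels first, then add the transformer
# output back onto h (the "vits2 residual connection" noted in the diff).
h = input_conv(xa) * x_mask
h = h + pre_transformer((h * x_mask).transpose(1, 2)).transpose(1, 2)
print(h.shape)  # torch.Size([4, 192, 50])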
