diff --git a/configs/_base_/schedulers/warmup_cosine.py b/configs/_base_/schedulers/warmup_cosine.py
index 219c9627..1aca4f06 100644
--- a/configs/_base_/schedulers/warmup_cosine.py
+++ b/configs/_base_/schedulers/warmup_cosine.py
@@ -12,7 +12,7 @@

 optimizer = dict(
     type="AdamW",
-    lr=1.0,
+    lr=2e-5,
     weight_decay=1e-2,
     betas=(0.9, 0.98),
     eps=1e-9,
diff --git a/configs/_base_/trainers/base.py b/configs/_base_/trainers/base.py
index 93c818f1..cde3b1c7 100644
--- a/configs/_base_/trainers/base.py
+++ b/configs/_base_/trainers/base.py
@@ -15,7 +15,7 @@
     check_val_every_n_epoch=None,
     max_steps=2_000_000,
     # Warning: If you are training the model with fs2 (and see nan), you should either use bf16 or fp32
-    precision="32",
+    precision="bf16-mixed",
     accumulate_grad_batches=1,
     callbacks=[
         ModelCheckpoint(
diff --git a/configs/tts_baseline.py b/configs/tts_baseline.py
index 96ef8c96..a054e04e 100644
--- a/configs/tts_baseline.py
+++ b/configs/tts_baseline.py
@@ -89,11 +89,9 @@
         pretrained=True,
     ),
     duration_predictor=dict(
-        type="TransformerEncoder",
+        type="NaiveProjectionEncoder",
         input_size=bert_dim,
         output_size=1,
-        hidden_size=bert_dim,
-        num_layers=1,
     ),
     vocoder=dict(
         type="ADaMoSHiFiGANV1",
@@ -125,6 +123,9 @@
     train=dict(
         batch_size=8,
     ),
+    valid=dict(
+        batch_size=8,
+    ),
 )

 preprocessing = dict(
diff --git a/fish_diffusion/archs/diffsinger/grad_tts.py b/fish_diffusion/archs/diffsinger/grad_tts.py
index e7bd146c..588f96fb 100644
--- a/fish_diffusion/archs/diffsinger/grad_tts.py
+++ b/fish_diffusion/archs/diffsinger/grad_tts.py
@@ -55,17 +55,7 @@ def forward_features(
         phones2mel=None,
         energy=None,
     ):
-        src_masks = (
-            self.get_mask_from_lengths(contents_lens, contents_max_len)
-            if contents_lens is not None
-            else None
-        )
-
-        mel_masks = (
-            self.get_mask_from_lengths(mel_lens, mel_max_len)
-            if mel_lens is not None
-            else None
-        )
+        src_masks = self.get_mask_from_lengths(contents_lens, contents_max_len)
         features = self.text_encoder.bert.embeddings.word_embeddings(contents)[
             :, 0, :, :
         ]
@@ -93,15 +83,15 @@ def forward_features(
         )

         # Predict durations
-        durations = self.duration_predictor(features, src_masks_float)
-        log_durations = durations[:, 0, 0]
+        log_durations = self.duration_predictor(features[:, 0, :])[..., 0]
         duration_loss = F.smooth_l1_loss(log_durations, torch.log(mel_lens.float()))

         if self.training is False:
             mel_lens = torch.round(torch.exp(torch.clamp(log_durations, 1, 8))).long()
             mel_lens = torch.clamp(mel_lens, 10, 2048)
             mel_max_len = torch.max(mel_lens).item()
-            mel_masks = self.get_mask_from_lengths(mel_lens, mel_max_len)
+
+        mel_masks = self.get_mask_from_lengths(mel_lens, mel_max_len)

         return dict(
             features=features,
diff --git a/fish_diffusion/modules/convnext.py b/fish_diffusion/modules/convnext.py
index 11028c91..680b654a 100644
--- a/fish_diffusion/modules/convnext.py
+++ b/fish_diffusion/modules/convnext.py
@@ -108,6 +108,8 @@ def __init__(

         self.diffusion_step_projection = nn.Conv1d(dim, dim, 1)
         self.register_buffer("positional_embedding", self.get_embedding(dim))
+        self.position_scale_query = nn.Parameter(torch.ones(1))
+        self.position_scale_key = nn.Parameter(torch.ones(1))

     def get_embedding(self, embedding_dim, num_embeddings=4096):
         half_dim = embedding_dim // 2
@@ -130,8 +132,13 @@ def forward(self, x, condition, diffusion_step, x_masks=None, cond_masks=None):
         x = x.transpose(1, 2)
         condition = condition.transpose(1, 2)

-        x = x + self.positional_embedding[: x.size(1)][None]
-        condition = condition + self.positional_embedding[: condition.size(1)][None]
+        # Apply learnable scales to the fixed positional embedding built by self.get_embedding(dim)
+        x = x + self.positional_embedding[: x.size(1)][None] * self.position_scale_query
+        condition = (
+            condition
+            + self.positional_embedding[: condition.size(1)][None]
+            * self.position_scale_key
+        )

         return (
             super()